[
  {
    "path": ".clauderules",
    "content": "# Claude Code Rules - Follow Every Rule Exactly\n\nYou must prioritize straightforward code semantics, well-named types, clear function signatures, and robust, carefully-chosen abstractions. Think about how your decisions might impact these aspects of code quality before proposing any changes.\n\nYou have access to all modern Python features from Python 3.13, 3.12, 3.11...\n\n**When you're done making changes, remove any redundant comments; remaining comments should only apply to complex code segments, adding relevant context.**\n\n## 1. Code Discipline\n\n* Eliminate superfluous `try`/`catch` and `if` branches through strict typing and static analysis.\n* Use pure functions unless you must mutate fixed state—then wrap that state in a class.\n* Every function is **referentially transparent**: same inputs ⇒ same outputs, no hidden state, no unintended I/O.\n* Put side-effects in injectable \"effect handlers\"; keep core logic pure.\n\n## 2. Naming\n\n* Choose descriptive, non-abbreviated names—no 3-letter acronyms or non-standard contractions.\n* Anyone reading a function's type signature alone should grasp its purpose without extra context.\n\n## 3. Typing\n\n* Maintain **strict, exhaustive** typing; never bypass the type-checker.\n* Default to `Literal[...]` when an enum-like set is needed.\n* Prefer built-in types; when two values share structure but differ in meaning, enforce separation:\n  * Use `typing.NewType` for primitives (zero runtime cost).\n  * For serializable objects, add a `type: str` field that states the object's identity.\n\n## 4. Pydantic\n\n* Read, respect, and rely on Pydantic documentation.\n* Centralize a common `ConfigDict` with `frozen=True` and `strict=True` (or stricter) and reuse it everywhere.\n* For hierarchies of `BaseModel` variants, declare a discriminated union with `typing.Annotated[Base, Field(discriminator='variant')]`; publish a single `TypeAdapter[Base]` so all variants share one strict validator.\n\n## 5. IDs & UUIDs\n\n* Subclass Pydantic's `UUID4` for custom ID types.\n* Generate fresh IDs with `uuid.uuid4()`.\n* Create idempotency keys by hashing *persisted* state plus a **function-specific salt** to avoid collisions after crashes.\n\n## 6. Error Handling\n\n* Catch an exception **only** where you can handle or transform it meaningfully.\n* State in the docstring **where** each exception is expected to be handled and **why**.\n\n## 7. Dependencies\n\n* Introduce new external dependencies only after approval.\n* Request only libraries common in production environments.\n\n## 8. Use of `@final` & Freezing\n\n* Mark classes, methods, and variables as `@final` or otherwise immutable wherever applicable.\n\n## 9. Repository Workflow\n\nIf you spot a rule violation within code that you've not been asked to work on directly, inform the user rather than patching it ad-hoc.\n\n---\n\n### One-Sentence Summary\n\nWrite strictly-typed, pure, self-describing Python that uses Pydantic, well-scoped side-effects, immutable state, approved dependencies, and explicit error handling."
  },
  {
    "path": ".cursorrules",
    "content": "# follow **every** rule exactly; report any violation instead of silently fixing it.\n\nYou must prioritize straightforward code semantics, well-named types, clear function signatures, and robust, carefully-chosen abstractions. Think about how your decisions might impact these aspects of code quality before proposing any changes.\n\nYou can use the advanced features of `typing`. You have access to all of the new features from Python 3.13, 3.12, 3.11...\n\n**When you're done making your changes, remove any redundant comments that you may have left; the comments that remain should only apply to complex segments of code, adding relevant context.**\n\n## 1. Code Discipline\n\n* Eliminate superfluous `try` / `catch` and `if` branches through strict typing and static analysis.\n* Use pure functions unless you must mutate fixed state—then wrap that state in a class.\n* Every function is **referentially transparent**: same inputs ⇒ same outputs, no hidden state, no unintended I/O.\n* Put side-effects in injectable “effect handlers”; keep core logic pure.\n\n## 2. Naming\n\n* Choose descriptive, non-abbreviated names—no 3-letter acronyms or non-standard contractions.\n* Anyone reading a function’s type signature alone should grasp its purpose without extra context.\n\n## 3. Typing\n\n* Maintain **strict, exhaustive** typing; never bypass the type-checker.\n* Default to `Literal[...]` when an enum-like set is needed.\n* Prefer built-in types; when two values share structure but differ in meaning, enforce separation:\n  * Use `typing.NewType` for primitives (zero runtime cost).\n  * For serialisable objects, add a `type: str` field that states the object’s identity.\n\n## 4. Pydantic\n\n* Read, respect, and rely on Pydantic docs.\n* Centralise a common `ConfigDict` with `frozen=True` and `strict=True` (or stricter) and reuse it everywhere.\n* For hierarchies of `BaseModel` variants, declare a discriminated union with `typing.Annotated[Base, Field(discriminator='variant')]`; publish a single `TypeAdapter[Base]` so all variants share one strict validator.\n\n## 5. IDs & UUIDs\n\n* Subclass Pydantic’s `UUID4` for custom ID types.\n* Generate fresh IDs with `uuid.uuid4()`.\n* Create idempotency keys by hashing *persisted* state plus a **function-specific salt** to avoid collisions after crashes.\n\n## 6. Error Handling\n\n* Catch an exception **only** where you can handle or transform it meaningfully.\n* State in the docstring **where** each exception is expected to be handled and **why**.\n\n## 7. Dependencies\n\n* Introduce new external dependencies only after approval.\n* Request only libraries common in production environments.\n\n## 8. Use of `@final` & Freezing\n\n* Mark classes, methods, and variables as `@final` or otherwise immutable wherever applicable.\n\n## 9. Repository Workflow\n\nIf you spot a rule violation within code that you've not been asked to work on directly, inform the user rather than patching it ad-hoc.\n\n\n---\n\n### One-Sentence Summary\n\nWrite strictly-typed, pure, self-describing Python that uses Pydantic, well-scoped side-effects, immutable state, approved dependencies, and explicit error handling\n"
  },
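  {
    "path": ".cursorrules-example.py",
    "content": "\"\"\"Hypothetical illustration, not part of the original repo: a minimal sketch\nof the typing, Pydantic, and ID rules stated in .clauderules/.cursorrules.\nEvery name below is invented for illustration only.\n\"\"\"\n\nimport hashlib\nimport uuid\nfrom typing import Annotated, Final, Literal, NewType, final\n\nfrom pydantic import BaseModel, ConfigDict, Field, TypeAdapter\n\n# Rule 3: NewType separates structurally identical primitives at zero runtime cost.\nUserIdentifier = NewType(\"UserIdentifier\", str)\n\n# Rule 4: one shared frozen, strict ConfigDict, reused by every model.\nSTRICT_FROZEN_CONFIG: Final = ConfigDict(frozen=True, strict=True)\n\n\n@final\nclass TextMessage(BaseModel):\n    model_config = STRICT_FROZEN_CONFIG\n    variant: Literal[\"text\"] = \"text\"\n    body: str\n\n\n@final\nclass ImageMessage(BaseModel):\n    model_config = STRICT_FROZEN_CONFIG\n    variant: Literal[\"image\"] = \"image\"\n    url: str\n\n\n# Rule 4: a discriminated union with a single TypeAdapter, so every variant\n# shares one strict validator, e.g.\n# MESSAGE_VALIDATOR.validate_json('{\"variant\": \"text\", \"body\": \"hi\"}').\nMessage = Annotated[TextMessage | ImageMessage, Field(discriminator=\"variant\")]\nMESSAGE_VALIDATOR: Final = TypeAdapter(Message)\n\n\n# Rule 5: fresh identifiers come from uuid.uuid4().\ndef generate_user_identifier() -> UserIdentifier:\n    return UserIdentifier(str(uuid.uuid4()))\n\n\n# Rule 5: idempotency keys hash *persisted* state plus a function-specific salt.\ndef derive_idempotency_key(persisted_state: bytes, function_salt: bytes) -> str:\n    return hashlib.sha256(function_salt + persisted_state).hexdigest()\n"
  },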
  {
    "path": ".envrc",
    "content": "use flake\n"
  },
  {
    "path": ".githooks/post-checkout",
    "content": "#!/bin/sh\ncommand -v git-lfs >/dev/null 2>&1 || { printf >&2 \"\\n%s\\n\\n\" \"This repository is configured for Git LFS but 'git-lfs' was not found on your path. If you no longer wish to use Git LFS, remove this hook by deleting the 'post-checkout' file in the hooks directory (set by 'core.hookspath'; usually '.git/hooks').\"; exit 2; }\ngit lfs post-checkout \"$@\"\n"
  },
  {
    "path": ".githooks/post-commit",
    "content": "#!/bin/sh\ncommand -v git-lfs >/dev/null 2>&1 || { printf >&2 \"\\n%s\\n\\n\" \"This repository is configured for Git LFS but 'git-lfs' was not found on your path. If you no longer wish to use Git LFS, remove this hook by deleting the 'post-commit' file in the hooks directory (set by 'core.hookspath'; usually '.git/hooks').\"; exit 2; }\ngit lfs post-commit \"$@\"\n"
  },
  {
    "path": ".githooks/post-merge",
    "content": "#!/bin/sh\ncommand -v git-lfs >/dev/null 2>&1 || { printf >&2 \"\\n%s\\n\\n\" \"This repository is configured for Git LFS but 'git-lfs' was not found on your path. If you no longer wish to use Git LFS, remove this hook by deleting the 'post-merge' file in the hooks directory (set by 'core.hookspath'; usually '.git/hooks').\"; exit 2; }\ngit lfs post-merge \"$@\"\n"
  },
  {
    "path": ".githooks/pre-push",
    "content": "#!/bin/sh\ncommand -v git-lfs >/dev/null 2>&1 || { printf >&2 \"\\n%s\\n\\n\" \"This repository is configured for Git LFS but 'git-lfs' was not found on your path. If you no longer wish to use Git LFS, remove this hook by deleting the 'pre-push' file in the hooks directory (set by 'core.hookspath'; usually '.git/hooks').\"; exit 2; }\ngit lfs pre-push \"$@\"\n"
  },
  {
    "path": ".github/CODEOWNERS",
    "content": "*   @ToxicPine\n*   @AlexCheema\n*   @GeluVrabie\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/bug_report.md",
    "content": "---\nname: Bug Report\nabout: Create a report to help us improve\ntitle: '[BUG] '\nlabels: bug\nassignees: ''\n---\n\n## Describe the bug\n\nA clear and concise description of what the bug is.\n\n## To Reproduce\n\nSteps to reproduce the behavior:\n1.\n2.\n3.\n\n## Expected behavior\n\nA clear and concise description of what you expected to happen.\n\n## Actual behavior\n\nA clear and concise description of what actually happened.\n\n## Environment\n\n- macOS Version:\n- EXO Version:\n- Hardware:\n  - Device 1: (e.g., MacBook Pro M1 Max, 32GB RAM)\n  - Device 2: (e.g., Mac Mini M2, 16GB RAM)\n  - Additional devices:\n- Interconnection:\n  - (e.g., Thunderbolt 4 cable between Device 1 and 2)\n  - (e.g., WiFi 6 for Device 3)\n  - (e.g., 10GbE Ethernet between all devices)\n\n## Additional context\n\nAdd any other context about the problem here.\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature_request.md",
    "content": "---\nname: Feature Request\nabout: Suggest an idea for this project\ntitle: ''\nlabels: enhancement\nassignees: ''\n---\n\n<!-- Please use a clear, descriptive title above -->\n\nDescribe what you'd like to see added to EXO.\n"
  },
  {
    "path": ".github/actions/conditional-commit/action.yml",
    "content": "name: Commit if changed\ndescription: \"Create a commit when the working tree is dirty\"\n\ninputs:\n  message:\n    description: \"Commit message\"\n    required: true\n\nruns:\n  using: composite\n  steps:\n    - name: Commit changed files\n      shell: bash\n      run: |\n        git diff --quiet && exit 0\n        git commit -am \"${{ inputs.message }}\"\n"
  },
  {
    "path": ".github/actions/format/action.yml",
    "content": "name: Format Code\n\ndescription: \"Run code formatter\"\n\nruns:\n  using: \"composite\"\n  steps:\n    - name: Format code\n      run: nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just fmt\n      shell: bash\n"
  },
  {
    "path": ".github/actions/lint/action.yml",
    "content": "name: Lint Code\n\ndescription: \"Run code linter\"\n\nruns:\n  using: \"composite\"\n  steps:\n    - name: Lint code\n      run: nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just lint\n      shell: bash\n"
  },
  {
    "path": ".github/actions/lint-check/action.yml",
    "content": "name: Lint Check\n\ndescription: \"Check for lint errors\"\n\nruns:\n  using: \"composite\"\n  steps:\n    - name: Lint check\n      run: nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just lint-check\n      shell: bash\n"
  },
  {
    "path": ".github/actions/regenerate-protobufs/action.yml",
    "content": "name: Regenerate Protobufs\n\ndescription: \"Regenerate protobuf files\"\n\nruns:\n  using: \"composite\"\n  steps:\n    - name: Regenerate protobufs\n      run: nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just regenerate-protobufs\n      shell: bash\n"
  },
  {
    "path": ".github/actions/setup-python-uv/action.yml",
    "content": "name: Setup Python & uv\n\ndescription: \"Regenerate Python environment from uv.lock\"\n\nruns:\n  using: \"composite\"\n  steps:\n    - name: Install uv\n      uses: astral-sh/setup-uv@v6\n      with:\n        enable-cache: true\n        cache-dependency-glob: uv.lock\n\n    - name: Install Python\n      run: uv python install\n      shell: bash\n\n    - name: Sync\n      run: uv sync --locked --all-extras --dev\n      shell: bash\n"
  },
  {
    "path": ".github/actions/unit-test/action.yml",
    "content": "name: Unit Test\n\ndescription: \"Run unit tests\"\n\nruns:\n  using: \"composite\"\n  steps:\n    - name: Run unit tests\n      run: |\n        nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just sync-clean\n        nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just test-fast\n      shell: bash\n"
  },
  {
    "path": ".github/actions/verify-clean/action.yml",
    "content": "name: Verify Clean Working Tree\n\ndescription: \"Fail the job if the previous step left the working tree dirty\"\n\ninputs:\n  step:\n    description: \"The name of the step that just executed\"\n    required: true\n\nruns:\n  using: composite\n  steps:\n    - name: Check git diff\n      shell: bash\n      run: |\n        if ! git diff --quiet; then\n          echo \"Error: ${{ inputs.step }} left working tree dirty.\" >&2\n          git --no-pager diff >&2\n          exit 1\n        fi "
  },
  {
    "path": ".github/pull_request_template.md",
    "content": "## Motivation\n\n<!-- Why is this change needed? What problem does it solve? -->\n<!-- If it fixes an open issue, please link to the issue here -->\n\n## Changes\n\n<!-- Describe what you changed in detail -->\n\n## Why It Works\n\n<!-- Explain why your approach solves the problem -->\n\n## Test Plan\n\n### Manual Testing\n<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB, connected via Thunderbolt 4) -->\n<!-- What you did: -->\n<!-- - -->\n\n### Automated Testing\n<!-- Describe changes to automated tests, or how existing tests cover this change -->\n<!-- - -->\n"
  },
  {
    "path": ".github/workflows/build-app.yml",
    "content": "name: Build EXO macOS DMG\n\n# Release workflow:\n# 1. Create a draft GitHub Release with the tag name (e.g. v1.0.0) and write release notes in markdown\n# 2. Push the tag: git tag v1.0.0 && git push origin v1.0.0\n# 3. This workflow builds, signs, and notarizes the DMG\n# 4. Release notes are embedded in appcast.xml for Sparkle (rendered as markdown)\n# 5. DMG and appcast.xml are uploaded to S3\n# 6. The draft GitHub Release is published with the DMG attached\n#\n# For alpha releases (e.g. v1.0.0-alpha.1): draft release and notes are optional.\n# If no draft exists, a release is auto-created with generated notes.\n\non:\n  workflow_dispatch:\n  push:\n    tags:\n      - \"v*\"\n    branches:\n      - \"test-app\"\n\njobs:\n  build-macos-app:\n    runs-on: \"macos-26\"\n    permissions:\n      contents: write\n    env:\n      SPARKLE_VERSION: 2.9.0-beta.1\n      SPARKLE_DOWNLOAD_PREFIX: ${{ secrets.SPARKLE_DOWNLOAD_PREFIX }}\n      SPARKLE_FEED_URL: ${{ secrets.SPARKLE_FEED_URL }}\n      SPARKLE_ED25519_PUBLIC: ${{ secrets.SPARKLE_ED25519_PUBLIC }}\n      SPARKLE_ED25519_PRIVATE: ${{ secrets.SPARKLE_ED25519_PRIVATE }}\n      SPARKLE_S3_BUCKET: ${{ secrets.SPARKLE_S3_BUCKET }}\n      SPARKLE_S3_PREFIX: ${{ secrets.SPARKLE_S3_PREFIX }}\n      EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT: ${{ secrets.EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT }}\n      AWS_REGION: ${{ secrets.AWS_REGION }}\n      EXO_BUILD_NUMBER: ${{ github.run_number }}\n      EXO_LIBP2P_NAMESPACE: ${{ github.ref_name }}\n\n    steps:\n      # ============================================================\n      # Checkout and tag validation\n      # ============================================================\n\n      - name: Checkout\n        uses: actions/checkout@v4\n        with:\n          fetch-depth: 0\n\n      - name: Derive release version from tag\n        run: |\n          if [[ \"$GITHUB_REF_NAME\" == \"test-app\" || \"${{ github.event_name }}\" == \"workflow_dispatch\" ]]; then\n            VERSION=\"0.0.0-alpha.0\"\n            echo \"IS_ALPHA=true\" >> $GITHUB_ENV\n          else\n            VERSION=\"${GITHUB_REF_NAME#v}\"\n            if [[ \"$VERSION\" == *-alpha* ]]; then\n              echo \"IS_ALPHA=true\" >> $GITHUB_ENV\n            else\n              echo \"IS_ALPHA=false\" >> $GITHUB_ENV\n            fi\n          fi\n          echo \"RELEASE_VERSION=$VERSION\" >> $GITHUB_ENV\n\n      - name: Compute build version from semver\n        run: |\n          VERSION=\"$RELEASE_VERSION\"\n          # Extract major.minor.patch (strip prerelease suffix)\n          BASE_VERSION=\"${VERSION%%-*}\"\n          MAJOR=$(echo \"$BASE_VERSION\" | cut -d. -f1)\n          MINOR=$(echo \"$BASE_VERSION\" | cut -d. -f2)\n          PATCH=$(echo \"$BASE_VERSION\" | cut -d. -f3)\n\n          # Extract prerelease number (e.g., \"alpha.2\" -> 2, or 999 for releases)\n          if [[ \"$VERSION\" == *-* ]]; then\n            PRERELEASE_PART=\"${VERSION#*-}\"\n            PRERELEASE_NUM=\"${PRERELEASE_PART##*.}\"\n            # Default to 0 if not a number\n            if ! 
[[ \"$PRERELEASE_NUM\" =~ ^[0-9]+$ ]]; then\n              PRERELEASE_NUM=0\n            fi\n          else\n            PRERELEASE_NUM=999\n          fi\n\n          # Compute: PRERELEASE + (1000 * PATCH) + (1_000_000 * MINOR) + (1_000_000_000 * MAJOR)\n          BUILD_VERSION=$((PRERELEASE_NUM + 1000 * PATCH + 1000000 * MINOR + 1000000000 * MAJOR))\n          echo \"EXO_BUILD_VERSION=$BUILD_VERSION\" >> $GITHUB_ENV\n          echo \"Computed build version: $BUILD_VERSION from $VERSION\"\n\n      - name: Ensure tag commit is on main\n        if: github.ref_type == 'tag'\n        run: |\n          git fetch origin main\n          # Alpha tags can be on any branch, production tags must be on main\n          if [[ \"$IS_ALPHA\" == \"true\" ]]; then\n            echo \"Alpha tag detected, skipping main branch check\"\n          elif ! git merge-base --is-ancestor origin/main HEAD; then\n            echo \"Production tag must point to a commit on main\"\n            exit 1\n          fi\n\n      - name: Fetch and validate release notes\n        if: github.ref_type == 'tag'\n        env:\n          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n        run: |\n          # Find draft release by name using gh release list (more reliable with default token)\n          echo \"Looking for draft release named '$GITHUB_REF_NAME'...\"\n          DRAFT_EXISTS=$(gh release list --json name,isDraft --jq \".[] | select(.isDraft == true) | select(.name == \\\"$GITHUB_REF_NAME\\\") | .name\" 2>/dev/null || echo \"\")\n\n          if [[ -z \"$DRAFT_EXISTS\" ]]; then\n            if [[ \"$IS_ALPHA\" == \"true\" ]]; then\n              echo \"No draft release found for alpha tag $GITHUB_REF_NAME (optional for alphas)\"\n              echo \"HAS_RELEASE_NOTES=false\" >> $GITHUB_ENV\n              exit 0\n            fi\n            echo \"ERROR: No draft release found for tag $GITHUB_REF_NAME\"\n            echo \"Please create a draft release with release notes before pushing the tag.\"\n            exit 1\n          fi\n\n          # Fetch full release details via API to get body and ID\n          echo \"Found draft release, fetching details...\"\n          RELEASE_JSON=$(gh api repos/${{ github.repository }}/releases --jq \".[] | select(.draft == true) | select(.name == \\\"$GITHUB_REF_NAME\\\")\" 2>/dev/null || echo \"\")\n\n          # Extract release notes\n          NOTES=$(echo \"$RELEASE_JSON\" | jq -r '.body // \"\"')\n          if [[ -z \"$NOTES\" || \"$NOTES\" == \"null\" ]]; then\n            if [[ \"$IS_ALPHA\" == \"true\" ]]; then\n              echo \"Draft release has no notes (optional for alphas)\"\n              echo \"HAS_RELEASE_NOTES=false\" >> $GITHUB_ENV\n              exit 0\n            fi\n            echo \"ERROR: Draft release exists but has no release notes\"\n            echo \"Please add release notes to the draft release before pushing the tag.\"\n            exit 1\n          fi\n\n          # Save release ID for later publishing\n          RELEASE_ID=$(echo \"$RELEASE_JSON\" | jq -r '.id')\n          echo \"DRAFT_RELEASE_ID=$RELEASE_ID\" >> $GITHUB_ENV\n          echo \"HAS_RELEASE_NOTES=true\" >> $GITHUB_ENV\n\n          echo \"Found draft release (ID: $RELEASE_ID), saving release notes...\"\n          echo \"$NOTES\" > /tmp/release_notes.md\n          echo \"RELEASE_NOTES_FILE=/tmp/release_notes.md\" >> $GITHUB_ENV\n\n      # ============================================================\n      # Install dependencies\n      # 
============================================================\n\n      - name: Select Xcode 26.2\n        run: |\n          sudo xcode-select -s /Applications/Xcode_26.2.app\n          if ! xcrun -f metal >/dev/null 2>&1; then\n            echo \"Metal toolchain is not installed.\"\n            exit 1\n          fi\n\n      - name: Install Homebrew packages\n        run: brew install just awscli macmon\n\n      - name: Install UV\n        uses: astral-sh/setup-uv@v6\n        with:\n          enable-cache: true\n          cache-dependency-glob: uv.lock\n\n      - name: Setup Python\n        run: |\n          uv python install\n          uv sync --locked\n\n      - name: Install Nix\n        uses: cachix/install-nix-action@v31\n        with:\n          nix_path: nixpkgs=channel:nixos-unstable\n\n      - name: Configure Cachix\n        uses: cachix/cachix-action@v14\n        with:\n          name: exo\n          authToken: \"${{ secrets.CACHIX_AUTH_TOKEN }}\"\n\n      - name: Build dashboard\n        run: |\n          DASHBOARD_OUT=$(nix build .#dashboard --print-build-logs --no-link --print-out-paths)\n          mkdir -p dashboard/build\n          cp -r \"$DASHBOARD_OUT\"/* dashboard/build/\n\n      - name: Install Sparkle CLI\n        run: |\n          CLI_URL=\"${SPARKLE_CLI_URL:-https://github.com/sparkle-project/Sparkle/releases/download/${SPARKLE_VERSION}/Sparkle-${SPARKLE_VERSION}.tar.xz}\"\n          echo \"Downloading Sparkle CLI from: $CLI_URL\"\n          mkdir -p /tmp/sparkle\n          curl --fail --location --output /tmp/sparkle.tar.xz \"$CLI_URL\"\n          tar -xJf /tmp/sparkle.tar.xz -C /tmp/sparkle --strip-components=1\n          echo \"SPARKLE_BIN=/tmp/sparkle/bin\" >> $GITHUB_ENV\n\n      - name: Prepare code-signing keychain\n        env:\n          MACOS_CERTIFICATE: ${{ secrets.MACOS_CERTIFICATE }}\n          MACOS_CERTIFICATE_PASSWORD: ${{ secrets.MACOS_CERTIFICATE_PASSWORD }}\n          PROVISIONING_PROFILE: ${{ secrets.PROVISIONING_PROFILE }}\n        run: |\n          KEYCHAIN_PATH=\"$HOME/Library/Keychains/build.keychain-db\"\n\n          # Create fresh keychain\n          security create-keychain -p \"$MACOS_CERTIFICATE_PASSWORD\" \"$KEYCHAIN_PATH\"\n\n          # Disable auto-lock (no timeout, no lock-on-sleep)\n          security set-keychain-settings \"$KEYCHAIN_PATH\"\n\n          # Add to search list while preserving existing keychains\n          security list-keychains -d user -s \"$KEYCHAIN_PATH\" $(security list-keychains -d user | tr -d '\"')\n\n          # Set as default and unlock\n          security default-keychain -s \"$KEYCHAIN_PATH\"\n          security unlock-keychain -p \"$MACOS_CERTIFICATE_PASSWORD\" \"$KEYCHAIN_PATH\"\n\n          # Import certificate with full access for codesign\n          echo \"$MACOS_CERTIFICATE\" | base64 --decode > /tmp/cert.p12\n          security import /tmp/cert.p12 -k \"$KEYCHAIN_PATH\" -P \"$MACOS_CERTIFICATE_PASSWORD\" \\\n            -T /usr/bin/codesign -T /usr/bin/security -T /usr/bin/productbuild\n          rm /tmp/cert.p12\n\n          # Allow codesign to access the key without prompting\n          security set-key-partition-list -S apple-tool:,apple:,codesign: -s -k \"$MACOS_CERTIFICATE_PASSWORD\" \"$KEYCHAIN_PATH\"\n\n          # Verify keychain is unlocked and identity is available\n          echo \"Verifying signing identity...\"\n          security find-identity -v -p codesigning \"$KEYCHAIN_PATH\"\n\n          # Setup provisioning profile\n          mkdir -p 
\"$HOME/Library/Developer/Xcode/UserData/Provisioning Profiles\"\n          echo \"$PROVISIONING_PROFILE\" | base64 --decode > \"$HOME/Library/Developer/Xcode/UserData/Provisioning Profiles/EXO.provisionprofile\"\n\n          # Export keychain path for other steps\n          echo \"BUILD_KEYCHAIN_PATH=$KEYCHAIN_PATH\" >> $GITHUB_ENV\n\n      # ============================================================\n      # Build the bundle\n      # ============================================================\n\n      - name: Build PyInstaller bundle\n        run: uv run pyinstaller packaging/pyinstaller/exo.spec\n\n      - name: Build Swift app\n        env:\n          MACOS_CERTIFICATE_PASSWORD: ${{ secrets.MACOS_CERTIFICATE_PASSWORD }}\n          SPARKLE_FEED_URL: ${{ secrets.SPARKLE_FEED_URL }}\n          SPARKLE_ED25519_PUBLIC: ${{ secrets.SPARKLE_ED25519_PUBLIC }}\n        run: |\n          cd app/EXO\n          security unlock-keychain -p \"$MACOS_CERTIFICATE_PASSWORD\" \"$BUILD_KEYCHAIN_PATH\"\n          SIGNING_IDENTITY=$(security find-identity -v -p codesigning \"$BUILD_KEYCHAIN_PATH\" | awk -F '\"' '{print $2}')\n          xcodebuild clean build \\\n            -scheme EXO \\\n            -configuration Release \\\n            -derivedDataPath build \\\n            MARKETING_VERSION=\"$RELEASE_VERSION\" \\\n            CURRENT_PROJECT_VERSION=\"$EXO_BUILD_VERSION\" \\\n            EXO_BUILD_TAG=\"$RELEASE_VERSION\" \\\n            EXO_BUILD_COMMIT=\"$GITHUB_SHA\" \\\n            SPARKLE_FEED_URL=\"$SPARKLE_FEED_URL\" \\\n            SPARKLE_ED25519_PUBLIC=\"$SPARKLE_ED25519_PUBLIC\" \\\n            EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT=\"$EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT\" \\\n            CODE_SIGNING_IDENTITY=\"$SIGNING_IDENTITY\" \\\n            CODE_SIGN_INJECT_BASE_ENTITLEMENTS=YES\n          mkdir -p ../../output\n          cp -R build/Build/Products/Release/EXO.app ../../output/EXO.app\n\n      - name: Inject PyInstaller runtime\n        run: |\n          rm -rf output/EXO.app/Contents/Resources/exo\n          mkdir -p output/EXO.app/Contents/Resources\n          cp -R dist/exo output/EXO.app/Contents/Resources/exo\n\n      - name: Codesign PyInstaller runtime\n        env:\n          MACOS_CERTIFICATE_PASSWORD: ${{ secrets.MACOS_CERTIFICATE_PASSWORD }}\n        run: |\n          cd output\n          security unlock-keychain -p \"$MACOS_CERTIFICATE_PASSWORD\" \"$BUILD_KEYCHAIN_PATH\"\n          SIGNING_IDENTITY=$(security find-identity -v -p codesigning \"$BUILD_KEYCHAIN_PATH\" | awk -F '\"' '{print $2}')\n          RUNTIME_DIR=\"EXO.app/Contents/Resources/exo\"\n          find \"$RUNTIME_DIR\" -type f \\( -perm -111 -o -name \"*.dylib\" -o -name \"*.so\" \\) -print0 |\n            while IFS= read -r -d '' file; do\n              /usr/bin/codesign --force --timestamp --options runtime \\\n                --sign \"$SIGNING_IDENTITY\" \"$file\"\n            done\n\n      - name: Sign, notarize, and create DMG\n        env:\n          MACOS_CERTIFICATE_PASSWORD: ${{ secrets.MACOS_CERTIFICATE_PASSWORD }}\n          APPLE_NOTARIZATION_USERNAME: ${{ secrets.APPLE_NOTARIZATION_USERNAME }}\n          APPLE_NOTARIZATION_PASSWORD: ${{ secrets.APPLE_NOTARIZATION_PASSWORD }}\n          APPLE_NOTARIZATION_TEAM: ${{ secrets.APPLE_NOTARIZATION_TEAM }}\n        run: |\n          cd output\n          security unlock-keychain -p \"$MACOS_CERTIFICATE_PASSWORD\" \"$BUILD_KEYCHAIN_PATH\"\n          SIGNING_IDENTITY=$(security find-identity -v -p codesigning \"$BUILD_KEYCHAIN_PATH\" | awk -F '\"' 
'{print $2}')\n          /usr/bin/codesign --deep --force --timestamp --options runtime \\\n            --sign \"$SIGNING_IDENTITY\" EXO.app\n          mkdir -p dmg-root\n          cp -R EXO.app dmg-root/\n          ln -s /Applications dmg-root/Applications\n          DMG_NAME=\"EXO-${RELEASE_VERSION}.dmg\"\n          hdiutil create -volname \"EXO\" -srcfolder dmg-root -ov -format UDZO \"$DMG_NAME\"\n          /usr/bin/codesign --force --timestamp --options runtime \\\n            --sign \"$SIGNING_IDENTITY\" \"$DMG_NAME\"\n          if [[ -n \"$APPLE_NOTARIZATION_USERNAME\" ]]; then\n            SUBMISSION_OUTPUT=$(xcrun notarytool submit \"$DMG_NAME\" \\\n              --apple-id \"$APPLE_NOTARIZATION_USERNAME\" \\\n              --password \"$APPLE_NOTARIZATION_PASSWORD\" \\\n              --team-id \"$APPLE_NOTARIZATION_TEAM\" \\\n              --wait --timeout 15m 2>&1)\n            echo \"$SUBMISSION_OUTPUT\"\n\n            SUBMISSION_ID=$(echo \"$SUBMISSION_OUTPUT\" | awk 'tolower($1)==\"id:\" && $2 ~ /^[0-9a-fA-F-]+$/ {print $2; exit}')\n            STATUS=$(echo \"$SUBMISSION_OUTPUT\" | awk 'tolower($1)==\"status:\" {print $2; exit}')\n\n            if [[ -n \"$SUBMISSION_ID\" ]]; then\n              xcrun notarytool log \"$SUBMISSION_ID\" \\\n                --apple-id \"$APPLE_NOTARIZATION_USERNAME\" \\\n                --password \"$APPLE_NOTARIZATION_PASSWORD\" \\\n                --team-id \"$APPLE_NOTARIZATION_TEAM\" > notarization-log.txt || true\n              echo \"===== Notarization Log =====\"\n              cat notarization-log.txt\n              echo \"============================\"\n            fi\n\n            if [[ \"$STATUS\" != \"Accepted\" ]]; then\n              echo \"Notarization failed with status: ${STATUS:-Unknown}\"\n              exit 1\n            fi\n\n            xcrun stapler staple \"$DMG_NAME\"\n          fi\n\n      - name: Generate Sparkle appcast\n        env:\n          SPARKLE_DOWNLOAD_PREFIX: ${{ env.SPARKLE_DOWNLOAD_PREFIX }}\n          SPARKLE_ED25519_PRIVATE: ${{ secrets.SPARKLE_ED25519_PRIVATE }}\n          IS_ALPHA: ${{ env.IS_ALPHA }}\n        run: |\n          set -euo pipefail\n          cd output\n          DOWNLOAD_PREFIX=\"${SPARKLE_DOWNLOAD_PREFIX:-https://assets.exolabs.net}\"\n          echo \"$SPARKLE_ED25519_PRIVATE\" > sparkle_ed25519.key\n          chmod 600 sparkle_ed25519.key\n\n          CHANNEL_FLAG=\"\"\n          if [[ \"$IS_ALPHA\" == \"true\" ]]; then\n            CHANNEL_FLAG=\"--channel alpha\"\n            echo \"Generating appcast for alpha channel\"\n          fi\n\n          $SPARKLE_BIN/generate_appcast \\\n            --ed-key-file sparkle_ed25519.key \\\n            --download-url-prefix \"$DOWNLOAD_PREFIX\" \\\n            $CHANNEL_FLAG \\\n            .\n\n      - name: Inject release notes into appcast\n        if: github.ref_type == 'tag' && env.HAS_RELEASE_NOTES == 'true'\n        env:\n          RELEASE_VERSION: ${{ env.RELEASE_VERSION }}\n        run: |\n          # Inject markdown release notes with sparkle:format=\"markdown\" (Sparkle 2.9+)\n          export NOTES=$(cat \"$RELEASE_NOTES_FILE\")\n\n          # Insert description after the enclosure tag for this version\n          awk '\n            /<enclosure[^>]*>/ && index($0, ENVIRON[\"RELEASE_VERSION\"]) {\n              print\n              print \"            <description sparkle:format=\\\"markdown\\\"><![CDATA[\"\n              print ENVIRON[\"NOTES\"]\n              print \"            ]]></description>\"\n              next\n            
}\n            { print }\n          ' output/appcast.xml > output/appcast.xml.tmp && mv output/appcast.xml.tmp output/appcast.xml\n\n          echo \"Injected markdown release notes for version $RELEASE_VERSION\"\n\n      # ============================================================\n      # Upload artifacts\n      # ============================================================\n\n      - name: Upload DMG\n        uses: actions/upload-artifact@v4\n        with:\n          name: EXO-dmg-${{ env.RELEASE_VERSION }}\n          path: output/EXO-${{ env.RELEASE_VERSION }}.dmg\n\n      - name: Upload to S3\n        if: env.SPARKLE_S3_BUCKET != ''\n        env:\n          AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}\n          AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}\n          AWS_REGION: ${{ env.AWS_REGION }}\n          SPARKLE_S3_BUCKET: ${{ env.SPARKLE_S3_BUCKET }}\n          SPARKLE_S3_PREFIX: ${{ env.SPARKLE_S3_PREFIX }}\n          IS_ALPHA: ${{ env.IS_ALPHA }}\n        run: |\n          set -euo pipefail\n          cd output\n          PREFIX=\"${SPARKLE_S3_PREFIX:-}\"\n          if [[ -n \"$PREFIX\" && \"${PREFIX: -1}\" != \"/\" ]]; then\n            PREFIX=\"${PREFIX}/\"\n          fi\n          DMG_NAME=\"EXO-${RELEASE_VERSION}.dmg\"\n\n          if [[ \"${{ github.ref_type }}\" != \"tag\" ]]; then\n            aws s3 cp \"$DMG_NAME\" \"s3://${SPARKLE_S3_BUCKET}/${PREFIX}EXO-${GITHUB_SHA}.dmg\"\n            exit 0\n          fi\n\n          aws s3 cp \"$DMG_NAME\" \"s3://${SPARKLE_S3_BUCKET}/${PREFIX}${DMG_NAME}\"\n          if [[ \"$IS_ALPHA\" != \"true\" ]]; then\n            aws s3 cp \"$DMG_NAME\" \"s3://${SPARKLE_S3_BUCKET}/${PREFIX}EXO-latest.dmg\"\n            aws s3 cp appcast.xml \"s3://${SPARKLE_S3_BUCKET}/${PREFIX}appcast.xml\" --content-type application/xml --cache-control no-cache\n          fi\n\n      - name: Publish GitHub Release\n        if: github.ref_type == 'tag'\n        env:\n          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n        run: |\n          DMG_PATH=\"output/EXO-${RELEASE_VERSION}.dmg\"\n\n          if [[ \"$HAS_RELEASE_NOTES\" == \"true\" ]]; then\n            # Update the draft release with the tag and upload DMG\n            gh api --method PATCH \"repos/${{ github.repository }}/releases/$DRAFT_RELEASE_ID\" \\\n              -f tag_name=\"$GITHUB_REF_NAME\" \\\n              -F draft=false\n            gh release upload \"$GITHUB_REF_NAME\" \"$DMG_PATH\" --clobber\n            echo \"Published release $GITHUB_REF_NAME with DMG attached\"\n          else\n            # Alpha without draft release - create one with auto-generated notes\n            gh release create \"$GITHUB_REF_NAME\" \"$DMG_PATH\" \\\n              --title \"$GITHUB_REF_NAME\" \\\n              --generate-notes \\\n              --prerelease\n            echo \"Created alpha release $GITHUB_REF_NAME with auto-generated notes\"\n          fi\n"
  },
  {
    "path": ".github/workflows/pipeline.yml",
    "content": "name: ci-pipeline\n\non:\n  push:\n  pull_request:\n    branches:\n      - staging\n      - main\n\njobs:\n  nix:\n    name: Build and check (${{ matrix.system }})\n    runs-on: ${{ matrix.runner }}\n    strategy:\n      fail-fast: false\n      matrix:\n        include:\n          - runner: macos-26\n            system: aarch64-darwin\n          - runner: ubuntu-latest\n            system: x86_64-linux\n          - runner: ubuntu-24.04-arm\n            system: aarch64-linux\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n        with:\n          lfs: false\n\n      - uses: cachix/install-nix-action@v31\n        with:\n          nix_path: nixpkgs=channel:nixos-unstable\n\n      - uses: cachix/cachix-action@v14\n        name: Configure Cachix\n        with:\n          name: exo\n          authToken: \"${{ secrets.CACHIX_AUTH_TOKEN }}\"\n\n      - name: Build Metal packages (macOS only)\n        if: runner.os == 'macOS'\n        run: |\n          # Try to build metal-toolchain first (may succeed via cachix cache hit)\n          if nix build .#metal-toolchain 2>/dev/null; then\n            echo \"metal-toolchain built successfully (likely cache hit)\"\n          else\n            echo \"metal-toolchain build failed, extracting from Xcode...\"\n\n            NAR_HASH=\"sha256-ayR5mXN4sZAddwKEG2OszGRF93k9ZFc7H0yi2xbylQw=\"\n            NAR_NAME=\"metal-toolchain-17C48.nar\"\n\n            # Use RUNNER_TEMP to avoid /tmp symlink issues on macOS\n            WORK_DIR=\"${RUNNER_TEMP}/metal-work\"\n            mkdir -p \"$WORK_DIR\"\n\n            # Download the Metal toolchain component\n            xcodebuild -downloadComponent MetalToolchain\n\n            # Find and mount the DMG\n            DMG_PATH=$(find /System/Library/AssetsV2/com_apple_MobileAsset_MetalToolchain -name '*.dmg' 2>/dev/null | head -1)\n            if [ -z \"$DMG_PATH\" ]; then\n              echo \"Error: Could not find Metal toolchain DMG\"\n              exit 1\n            fi\n\n            echo \"Found DMG at: $DMG_PATH\"\n            hdiutil attach \"$DMG_PATH\" -mountpoint \"${WORK_DIR}/metal-dmg\"\n\n            # Copy the toolchain\n            cp -R \"${WORK_DIR}/metal-dmg/Metal.xctoolchain\" \"${WORK_DIR}/metal-export\"\n            hdiutil detach \"${WORK_DIR}/metal-dmg\"\n\n            # Create NAR and add to store\n            nix nar pack \"${WORK_DIR}/metal-export\" > \"${WORK_DIR}/${NAR_NAME}\"\n            STORE_PATH=$(nix store add --mode flat \"${WORK_DIR}/${NAR_NAME}\")\n            echo \"Added NAR to store: $STORE_PATH\"\n\n            # Verify the hash matches\n            ACTUAL_HASH=$(nix hash file \"${WORK_DIR}/${NAR_NAME}\")\n            if [ \"$ACTUAL_HASH\" != \"$NAR_HASH\" ]; then\n              echo \"Warning: NAR hash mismatch!\"\n              echo \"Expected: $NAR_HASH\"\n              echo \"Actual:   $ACTUAL_HASH\"\n              echo \"The metal-toolchain.nix may need updating\"\n            fi\n\n            # Clean up\n            rm -rf \"$WORK_DIR\"\n\n            # Retry the build now that NAR is in store\n            nix build .#metal-toolchain\n          fi\n\n          # Build mlx (depends on metal-toolchain)\n          nix build .#mlx\n\n      - name: Build all Nix outputs\n        run: |\n          nix flake show --json | jq -r '\n            [\n              (.packages.\"${{ matrix.system }}\" // {} | keys[] | \".#packages.${{ matrix.system }}.\\(.)\"),\n              (.devShells.\"${{ matrix.system }}\" // {} | keys[] | 
\".#devShells.${{ matrix.system }}.\\(.)\")\n            ] | .[]\n          ' | xargs nix build\n\n      - name: Run nix flake check\n        run: nix flake check\n\n      - name: Run pytest (macOS only)\n        if: runner.os == 'macOS'\n        run: |\n          # Build the test environment (requires relaxed sandbox for uv2nix on macOS)\n          TEST_ENV=$(nix build '.#exo-test-env' --option sandbox relaxed --print-out-paths)\n\n          # Run pytest outside sandbox (needs GPU access for MLX)\n          export HOME=\"$RUNNER_TEMP\"\n          export EXO_TESTS=1\n          export EXO_DASHBOARD_DIR=\"$PWD/dashboard/\" \n          export EXO_RESOURCES_DIR=\"$PWD/resources\" \n          $TEST_ENV/bin/python -m pytest src -m \"not slow\" --import-mode=importlib\n"
  },
  {
    "path": ".gitignore",
    "content": "# gitingest\ndigest.txt\n\n# python\n**/__pycache__\n\n# nix\n.direnv/\n\n# IDEA (PyCharm)\n.idea\n\n# xcode / macos\n*.xcuserstate\n*.xcuserdata\n*.xcuserdatad/\n**/.DS_Store\napp/EXO/build/\ndist/\n\n\n# rust\ntarget/\n**/*.rs.bk\n*.pdb\n\n# svelte\ndashboard/build/\ndashboard/node_modules/\ndashboard/.svelte-kit/\n\n# host config snapshots\nhosts_*.json\n.swp\n\n# bench files\nbench/**/*.json\n\n# tmp\ntmp/models\n"
  },
  {
    "path": ".mlx_typings/.gitkeep",
    "content": ""
  },
  {
    "path": ".mlx_typings/mflux/__init__.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport os\n\nif \"TOKENIZERS_PARALLELISM\" not in os.environ: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/callbacks/__init__.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n"
  },
  {
    "path": ".mlx_typings/mflux/callbacks/callback.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nimport PIL.Image\nimport tqdm\nfrom typing import Protocol\nfrom mflux.models.common.config.config import Config\n\nclass BeforeLoopCallback(Protocol):\n    def call_before_loop(\n        self,\n        seed: int,\n        prompt: str,\n        latents: mx.array,\n        config: Config,\n        canny_image: PIL.Image.Image | None = ...,\n        depth_image: PIL.Image.Image | None = ...,\n    ) -> None: ...\n\nclass InLoopCallback(Protocol):\n    def call_in_loop(\n        self,\n        t: int,\n        seed: int,\n        prompt: str,\n        latents: mx.array,\n        config: Config,\n        time_steps: tqdm,\n    ) -> None: ...\n\nclass AfterLoopCallback(Protocol):\n    def call_after_loop(\n        self, seed: int, prompt: str, latents: mx.array, config: Config\n    ) -> None: ...\n\nclass InterruptCallback(Protocol):\n    def call_interrupt(\n        self,\n        t: int,\n        seed: int,\n        prompt: str,\n        latents: mx.array,\n        config: Config,\n        time_steps: tqdm,\n    ) -> None: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/callbacks/callback_registry.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import TYPE_CHECKING\nfrom mflux.callbacks.callback import (\n    AfterLoopCallback,\n    BeforeLoopCallback,\n    InLoopCallback,\n    InterruptCallback,\n)\nfrom mflux.callbacks.generation_context import GenerationContext\nfrom mflux.models.common.config.config import Config\n\nif TYPE_CHECKING: ...\n\nclass CallbackRegistry:\n    def __init__(self) -> None: ...\n    def register(self, callback) -> None: ...\n    def start(self, seed: int, prompt: str, config: Config) -> GenerationContext: ...\n    def before_loop_callbacks(self) -> list[BeforeLoopCallback]: ...\n    def in_loop_callbacks(self) -> list[InLoopCallback]: ...\n    def after_loop_callbacks(self) -> list[AfterLoopCallback]: ...\n    def interrupt_callbacks(self) -> list[InterruptCallback]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/callbacks/generation_context.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nimport PIL.Image\nimport tqdm\nfrom typing import TYPE_CHECKING\nfrom mflux.callbacks.callback_registry import CallbackRegistry\nfrom mflux.models.common.config.config import Config\n\nif TYPE_CHECKING: ...\n\nclass GenerationContext:\n    def __init__(\n        self, registry: CallbackRegistry, seed: int, prompt: str, config: Config\n    ) -> None: ...\n    def before_loop(\n        self,\n        latents: mx.array,\n        *,\n        canny_image: PIL.Image.Image | None = ...,\n        depth_image: PIL.Image.Image | None = ...,\n    ) -> None: ...\n    def in_loop(self, t: int, latents: mx.array, time_steps: tqdm = ...) -> None: ...\n    def after_loop(self, latents: mx.array) -> None: ...\n    def interruption(\n        self, t: int, latents: mx.array, time_steps: tqdm = ...\n    ) -> None: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/cli/__init__.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n"
  },
  {
    "path": ".mlx_typings/mflux/cli/defaults/defaults.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport os\n\nBATTERY_PERCENTAGE_STOP_LIMIT = ...\nCONTROLNET_STRENGTH = ...\nDEFAULT_DEV_FILL_GUIDANCE = ...\nDEFAULT_DEPTH_GUIDANCE = ...\nDIMENSION_STEP_PIXELS = ...\nGUIDANCE_SCALE = ...\nGUIDANCE_SCALE_KONTEXT = ...\nIMAGE_STRENGTH = ...\nMODEL_CHOICES = ...\nMODEL_INFERENCE_STEPS = ...\nQUANTIZE_CHOICES = ...\nif os.environ.get(\"MFLUX_CACHE_DIR\"):\n    MFLUX_CACHE_DIR = ...\nelse:\n    MFLUX_CACHE_DIR = ...\nMFLUX_LORA_CACHE_DIR = ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/__init__.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/__init__.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/cli/__init__.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/config/__init__.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom mflux.models.common.config.config import Config\nfrom mflux.models.common.config.model_config import ModelConfig\n\n__all__ = [\"Config\", \"ModelConfig\"]\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/config/config.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom pathlib import Path\nfrom typing import Any\nfrom tqdm import tqdm\nfrom mflux.models.common.config.model_config import ModelConfig\n\nlogger = ...\n\nclass Config:\n    def __init__(\n        self,\n        model_config: ModelConfig,\n        num_inference_steps: int = ...,\n        height: int = ...,\n        width: int = ...,\n        guidance: float = ...,\n        image_path: Path | str | None = ...,\n        image_strength: float | None = ...,\n        depth_image_path: Path | str | None = ...,\n        redux_image_paths: list[Path | str] | None = ...,\n        redux_image_strengths: list[float] | None = ...,\n        masked_image_path: Path | str | None = ...,\n        controlnet_strength: float | None = ...,\n        scheduler: str = ...,\n    ) -> None: ...\n    @property\n    def height(self) -> int: ...\n    @property\n    def width(self) -> int: ...\n    @width.setter\n    def width(self, value):  # -> None:\n        ...\n    @property\n    def image_seq_len(self) -> int: ...\n    @property\n    def guidance(self) -> float: ...\n    @property\n    def num_inference_steps(self) -> int: ...\n    @property\n    def precision(self) -> mx.Dtype: ...\n    @property\n    def num_train_steps(self) -> int: ...\n    @property\n    def image_path(self) -> Path | None: ...\n    @property\n    def image_strength(self) -> float | None: ...\n    @property\n    def depth_image_path(self) -> Path | None: ...\n    @property\n    def redux_image_paths(self) -> list[Path] | None: ...\n    @property\n    def redux_image_strengths(self) -> list[float] | None: ...\n    @property\n    def masked_image_path(self) -> Path | None: ...\n    @property\n    def init_time_step(self) -> int: ...\n    @property\n    def time_steps(self) -> tqdm: ...\n    @property\n    def controlnet_strength(self) -> float | None: ...\n    @property\n    def scheduler(self) -> Any: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/config/model_config.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom functools import lru_cache\nfrom typing import Literal\n\nclass ModelConfig:\n    precision: mx.Dtype = ...\n    def __init__(\n        self,\n        priority: int,\n        aliases: list[str],\n        model_name: str,\n        base_model: str | None,\n        controlnet_model: str | None,\n        custom_transformer_model: str | None,\n        num_train_steps: int | None,\n        max_sequence_length: int | None,\n        supports_guidance: bool | None,\n        requires_sigma_shift: bool | None,\n        transformer_overrides: dict | None = ...,\n    ) -> None: ...\n    @staticmethod\n    @lru_cache\n    def dev() -> ModelConfig: ...\n    @staticmethod\n    @lru_cache\n    def schnell() -> ModelConfig: ...\n    @staticmethod\n    @lru_cache\n    def dev_kontext() -> ModelConfig: ...\n    @staticmethod\n    @lru_cache\n    def dev_fill() -> ModelConfig: ...\n    @staticmethod\n    @lru_cache\n    def dev_redux() -> ModelConfig: ...\n    @staticmethod\n    @lru_cache\n    def dev_depth() -> ModelConfig: ...\n    @staticmethod\n    @lru_cache\n    def dev_controlnet_canny() -> ModelConfig: ...\n    @staticmethod\n    @lru_cache\n    def schnell_controlnet_canny() -> ModelConfig: ...\n    @staticmethod\n    @lru_cache\n    def dev_controlnet_upscaler() -> ModelConfig: ...\n    @staticmethod\n    @lru_cache\n    def dev_fill_catvton() -> ModelConfig: ...\n    @staticmethod\n    @lru_cache\n    def krea_dev() -> ModelConfig: ...\n    @staticmethod\n    @lru_cache\n    def flux2_klein_4b() -> ModelConfig: ...\n    @staticmethod\n    @lru_cache\n    def flux2_klein_9b() -> ModelConfig: ...\n    @staticmethod\n    @lru_cache\n    def qwen_image() -> ModelConfig: ...\n    @staticmethod\n    @lru_cache\n    def qwen_image_edit() -> ModelConfig: ...\n    @staticmethod\n    @lru_cache\n    def fibo() -> ModelConfig: ...\n    @staticmethod\n    @lru_cache\n    def z_image_turbo() -> ModelConfig: ...\n    @staticmethod\n    @lru_cache\n    def seedvr2_3b() -> ModelConfig: ...\n    def x_embedder_input_dim(self) -> int: ...\n    def is_canny(self) -> bool: ...\n    @staticmethod\n    def from_name(\n        model_name: str, base_model: Literal[\"dev\", \"schnell\", \"krea-dev\"] | None = ...\n    ) -> ModelConfig: ...\n\nAVAILABLE_MODELS = ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/latent_creator/__init__.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\n\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/latent_creator/latent_creator.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING, TypeAlias\nfrom mlx import nn\nfrom mflux.models.common.vae.tiling_config import TilingConfig\nfrom mflux.models.fibo.latent_creator.fibo_latent_creator import FiboLatentCreator\nfrom mflux.models.flux.latent_creator.flux_latent_creator import FluxLatentCreator\nfrom mflux.models.qwen.latent_creator.qwen_latent_creator import QwenLatentCreator\nfrom mflux.models.z_image.latent_creator.z_image_latent_creator import (\n    ZImageLatentCreator,\n)\n\nif TYPE_CHECKING:\n    LatentCreatorType: TypeAlias = type[\n        FiboLatentCreator | FluxLatentCreator | QwenLatentCreator | ZImageLatentCreator\n    ]\n\nclass Img2Img:\n    def __init__(\n        self,\n        vae: nn.Module,\n        latent_creator: LatentCreatorType,\n        sigmas: mx.array,\n        init_time_step: int,\n        image_path: str | Path | None,\n        tiling_config: TilingConfig | None = ...,\n    ) -> None: ...\n\nclass LatentCreator:\n    @staticmethod\n    def create_for_txt2img_or_img2img(\n        seed: int, height: int, width: int, img2img: Img2Img\n    ) -> mx.array: ...\n    @staticmethod\n    def encode_image(\n        vae: nn.Module,\n        image_path: str | Path,\n        height: int,\n        width: int,\n        tiling_config: TilingConfig | None = ...,\n    ) -> mx.array: ...\n    @staticmethod\n    def add_noise_by_interpolation(\n        clean: mx.array, noise: mx.array, sigma: float\n    ) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/lora/__init__.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/lora/layer/fused_linear_lora_layer.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom mlx import nn\nfrom mflux.models.common.lora.layer.linear_lora_layer import LoRALinear\n\nclass FusedLoRALinear(nn.Module):\n    def __init__(\n        self, base_linear: nn.Linear | nn.QuantizedLinear, loras: list[LoRALinear]\n    ) -> None: ...\n    def __call__(self, x):  # -> array:\n        ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/lora/layer/linear_lora_layer.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom mlx import nn\n\nclass LoRALinear(nn.Module):\n    @staticmethod\n    def from_linear(\n        linear: nn.Linear | nn.QuantizedLinear, r: int = ..., scale: float = ...\n    ):  # -> LoRALinear:\n        ...\n    def __init__(\n        self,\n        input_dims: int,\n        output_dims: int,\n        r: int = ...,\n        scale: float = ...,\n        bias: bool = ...,\n    ) -> None: ...\n    def __call__(self, x):  # -> array:\n        ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/lora/mapping/lora_loader.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nimport mlx.nn as nn\nfrom collections.abc import Callable\nfrom dataclasses import dataclass\nfrom mflux.models.common.lora.mapping.lora_mapping import LoRATarget\n\n@dataclass\nclass PatternMatch:\n    source_pattern: str\n    target_path: str\n    matrix_name: str\n    transpose: bool\n    transform: Callable[[mx.array], mx.array] | None = ...\n\nclass LoRALoader:\n    @staticmethod\n    def load_and_apply_lora(\n        lora_mapping: list[LoRATarget],\n        transformer: nn.Module,\n        lora_paths: list[str] | None = ...,\n        lora_scales: list[float] | None = ...,\n    ) -> tuple[list[str], list[float]]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/lora/mapping/lora_mapping.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom collections.abc import Callable\nfrom dataclasses import dataclass\nfrom typing import List, Protocol\n\n@dataclass\nclass LoRATarget:\n    model_path: str\n    possible_up_patterns: List[str]\n    possible_down_patterns: List[str]\n    possible_alpha_patterns: List[str] = ...\n    up_transform: Callable[[mx.array], mx.array] | None = ...\n    down_transform: Callable[[mx.array], mx.array] | None = ...\n\nclass LoRAMapping(Protocol):\n    @staticmethod\n    def get_mapping() -> List[LoRATarget]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/lora/mapping/lora_saver.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.nn as nn\n\nclass LoRASaver:\n    @staticmethod\n    def bake_and_strip_lora(module: nn.Module) -> nn.Module: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/lora/mapping/lora_transforms.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\n\nclass LoraTransforms:\n    @staticmethod\n    def split_q_up(tensor: mx.array) -> mx.array: ...\n    @staticmethod\n    def split_k_up(tensor: mx.array) -> mx.array: ...\n    @staticmethod\n    def split_v_up(tensor: mx.array) -> mx.array: ...\n    @staticmethod\n    def split_q_down(tensor: mx.array) -> mx.array: ...\n    @staticmethod\n    def split_k_down(tensor: mx.array) -> mx.array: ...\n    @staticmethod\n    def split_v_down(tensor: mx.array) -> mx.array: ...\n    @staticmethod\n    def split_single_q_up(tensor: mx.array) -> mx.array: ...\n    @staticmethod\n    def split_single_k_up(tensor: mx.array) -> mx.array: ...\n    @staticmethod\n    def split_single_v_up(tensor: mx.array) -> mx.array: ...\n    @staticmethod\n    def split_single_mlp_up(tensor: mx.array) -> mx.array: ...\n    @staticmethod\n    def split_single_q_down(tensor: mx.array) -> mx.array: ...\n    @staticmethod\n    def split_single_k_down(tensor: mx.array) -> mx.array: ...\n    @staticmethod\n    def split_single_v_down(tensor: mx.array) -> mx.array: ...\n    @staticmethod\n    def split_single_mlp_down(tensor: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/resolution/__init__.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom mflux.models.common.resolution.config_resolution import ConfigResolution\nfrom mflux.models.common.resolution.lora_resolution import LoraResolution\nfrom mflux.models.common.resolution.path_resolution import PathResolution\nfrom mflux.models.common.resolution.quantization_resolution import (\n    QuantizationResolution,\n)\n\n__all__ = [\n    \"ConfigResolution\",\n    \"LoraResolution\",\n    \"PathResolution\",\n    \"QuantizationResolution\",\n]\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/resolution/actions.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom enum import Enum\nfrom typing import NamedTuple\n\nclass QuantizationAction(Enum):\n    NONE = ...\n    STORED = ...\n    REQUESTED = ...\n\nclass PathAction(Enum):\n    LOCAL = ...\n    HUGGINGFACE_CACHED = ...\n    HUGGINGFACE = ...\n    ERROR = ...\n\nclass LoraAction(Enum):\n    LOCAL = ...\n    REGISTRY = ...\n    HUGGINGFACE_COLLECTION_CACHED = ...\n    HUGGINGFACE_COLLECTION = ...\n    HUGGINGFACE_REPO_CACHED = ...\n    HUGGINGFACE_REPO = ...\n    ERROR = ...\n\nclass ConfigAction(Enum):\n    EXACT_MATCH = ...\n    EXPLICIT_BASE = ...\n    INFER_SUBSTRING = ...\n    ERROR = ...\n\nclass Rule(NamedTuple):\n    priority: int\n    name: str\n    check: str\n    action: QuantizationAction | PathAction | LoraAction | ConfigAction\n    ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/resolution/config_resolution.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import TYPE_CHECKING\nfrom mflux.models.common.config.model_config import ModelConfig\n\nif TYPE_CHECKING: ...\nlogger = ...\n\nclass ConfigResolution:\n    RULES = ...\n    @staticmethod\n    def resolve(model_name: str, base_model: str | None = ...) -> ModelConfig: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/resolution/lora_resolution.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom pathlib import Path\n\nlogger = ...\n\nclass LoraResolution:\n    RULES = ...\n    _registry: dict[str, Path] = ...\n    @staticmethod\n    def resolve(path: str) -> str: ...\n    @staticmethod\n    def resolve_paths(paths: list[str] | None) -> list[str]: ...\n    @staticmethod\n    def resolve_scales(scales: list[float] | None, num_paths: int) -> list[float]: ...\n    @staticmethod\n    def get_registry() -> dict[str, Path]: ...\n    @staticmethod\n    def discover_files(library_paths: list[Path]) -> dict[str, Path]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/resolution/path_resolution.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom pathlib import Path\n\nlogger = ...\n\nclass PathResolution:\n    RULES = ...\n    @staticmethod\n    def resolve(path: str | None, patterns: list[str] | None = ...) -> Path | None: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/resolution/quantization_resolution.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nlogger = ...\n\nclass QuantizationResolution:\n    RULES = ...\n    @staticmethod\n    def resolve(\n        stored: int | None, requested: int | None\n    ) -> tuple[int | None, str | None]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/schedulers/__init__.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom .flow_match_euler_discrete_scheduler import FlowMatchEulerDiscreteScheduler\nfrom .linear_scheduler import LinearScheduler\nfrom .seedvr2_euler_scheduler import SeedVR2EulerScheduler\n\n__all__ = [\n    \"LinearScheduler\",\n    \"FlowMatchEulerDiscreteScheduler\",\n    \"SeedVR2EulerScheduler\",\n]\n\nclass SchedulerModuleNotFound(ValueError): ...\nclass SchedulerClassNotFound(ValueError): ...\nclass InvalidSchedulerType(TypeError): ...\n\nSCHEDULER_REGISTRY = ...\n\ndef register_contrib(scheduler_object, scheduler_name=...):  # -> None:\n    ...\ndef try_import_external_scheduler(\n    scheduler_object_path: str,\n):  # -> type[BaseScheduler]:\n    ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/schedulers/base_scheduler.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom abc import ABC, abstractmethod\n\nclass BaseScheduler(ABC):\n    @property\n    @abstractmethod\n    def sigmas(self) -> mx.array: ...\n    @abstractmethod\n    def step(\n        self, noise: mx.array, timestep: int, latents: mx.array, **kwargs\n    ) -> mx.array: ...\n    def scale_model_input(self, latents: mx.array, t: int) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/schedulers/flow_match_euler_discrete_scheduler.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom typing import TYPE_CHECKING\nfrom mflux.models.common.config.config import Config\nfrom mflux.models.common.schedulers.base_scheduler import BaseScheduler\n\nif TYPE_CHECKING: ...\n\nclass FlowMatchEulerDiscreteScheduler(BaseScheduler):\n    def __init__(self, config: Config) -> None: ...\n    @property\n    def sigmas(self) -> mx.array: ...\n    @property\n    def timesteps(self) -> mx.array: ...\n    def set_image_seq_len(self, image_seq_len: int) -> None: ...\n    @staticmethod\n    def get_timesteps_and_sigmas(\n        image_seq_len: int, num_inference_steps: int, num_train_timesteps: int = ...\n    ) -> tuple[mx.array, mx.array]: ...\n    def step(\n        self, noise: mx.array, timestep: int, latents: mx.array, **kwargs\n    ) -> mx.array: ...\n    def scale_model_input(self, latents: mx.array, t: int) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/schedulers/linear_scheduler.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom typing import TYPE_CHECKING\nfrom mflux.models.common.config.config import Config\nfrom mflux.models.common.schedulers.base_scheduler import BaseScheduler\n\nif TYPE_CHECKING: ...\n\nclass LinearScheduler(BaseScheduler):\n    def __init__(self, config: Config) -> None: ...\n    @property\n    def sigmas(self) -> mx.array: ...\n    @property\n    def timesteps(self) -> mx.array: ...\n    def step(\n        self, noise: mx.array, timestep: int, latents: mx.array, **kwargs\n    ) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/schedulers/seedvr2_euler_scheduler.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom typing import TYPE_CHECKING\nfrom mflux.models.common.config.config import Config\nfrom mflux.models.common.schedulers.base_scheduler import BaseScheduler\n\nif TYPE_CHECKING: ...\n\nclass SeedVR2EulerScheduler(BaseScheduler):\n    def __init__(self, config: Config) -> None: ...\n    @property\n    def timesteps(self) -> mx.array: ...\n    @property\n    def sigmas(self) -> mx.array: ...\n    def step(\n        self, noise: mx.array, timestep: int, latents: mx.array, **kwargs\n    ) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/tokenizer/__init__.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom mflux.models.common.tokenizer.tokenizer import (\n    BaseTokenizer,\n    LanguageTokenizer,\n    Tokenizer,\n    VisionLanguageTokenizer,\n)\nfrom mflux.models.common.tokenizer.tokenizer_loader import TokenizerLoader\nfrom mflux.models.common.tokenizer.tokenizer_output import TokenizerOutput\n\n\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n__all__ = [\n    \"Tokenizer\",\n    \"BaseTokenizer\",\n    \"LanguageTokenizer\",\n    \"VisionLanguageTokenizer\",\n    \"TokenizerLoader\",\n    \"TokenizerOutput\",\n]\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/tokenizer/tokenizer.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom abc import ABC, abstractmethod\nfrom typing import Protocol, runtime_checkable\nfrom PIL import Image\nfrom transformers import PreTrainedTokenizer\nfrom mflux.models.common.tokenizer.tokenizer_output import TokenizerOutput\n\n\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\n@runtime_checkable\nclass Tokenizer(Protocol):\n    tokenizer: PreTrainedTokenizer\n    def tokenize(\n        self,\n        prompt: str | list[str],\n        images: list[Image.Image] | None = ...,\n        max_length: int | None = ...,\n        **kwargs,\n    ) -> TokenizerOutput: ...\n\nclass BaseTokenizer(ABC):\n    def __init__(\n        self, tokenizer: PreTrainedTokenizer, max_length: int = ...\n    ) -> None: ...\n    @abstractmethod\n    def tokenize(\n        self,\n        prompt: str | list[str],\n        images: list[Image.Image] | None = ...,\n        max_length: int | None = ...,\n        **kwargs,\n    ) -> TokenizerOutput: ...\n\nclass LanguageTokenizer(BaseTokenizer):\n    def __init__(\n        self,\n        tokenizer: PreTrainedTokenizer,\n        max_length: int = ...,\n        padding: str = ...,\n        return_attention_mask: bool = ...,\n        template: str | None = ...,\n        use_chat_template: bool = ...,\n        chat_template_kwargs: dict | None = ...,\n        add_special_tokens: bool = ...,\n    ) -> None: ...\n    def tokenize(\n        self,\n        prompt: str | list[str],\n        images: list[Image.Image] | None = ...,\n        max_length: int | None = ...,\n        **kwargs,\n    ) -> TokenizerOutput: ...\n\nclass VisionLanguageTokenizer(BaseTokenizer):\n    def __init__(\n        self,\n        tokenizer: PreTrainedTokenizer,\n        processor,\n        max_length: int = ...,\n        template: str | None = ...,\n        image_token: str = ...,\n    ) -> None: ...\n    def tokenize(\n        self,\n        prompt: str | list[str],\n        images: list[Image.Image] | None = ...,\n        max_length: int | None = ...,\n        **kwargs,\n    ) -> TokenizerOutput: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/tokenizer/tokenizer_loader.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import TYPE_CHECKING\nfrom mflux.models.common.tokenizer.tokenizer import BaseTokenizer\nfrom mflux.models.common.weights.loading.weight_definition import TokenizerDefinition\n\n\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\nif TYPE_CHECKING: ...\n\nclass TokenizerLoader:\n    @staticmethod\n    def load(definition: TokenizerDefinition, model_path: str) -> BaseTokenizer: ...\n    @staticmethod\n    def load_all(\n        definitions: list[TokenizerDefinition],\n        model_path: str,\n        max_length_overrides: dict[str, int] | None = ...,\n    ) -> dict[str, BaseTokenizer]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/tokenizer/tokenizer_output.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom dataclasses import dataclass\n\n\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\n@dataclass\nclass TokenizerOutput:\n    input_ids: mx.array\n    attention_mask: mx.array\n    pixel_values: mx.array | None = ...\n    image_grid_thw: mx.array | None = ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/vae/__init__.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom mflux.models.common.vae.tiling_config import TilingConfig\nfrom mflux.models.common.vae.vae_tiler import VAETiler\n\n__all__ = [\"TilingConfig\", \"VAETiler\"]\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/vae/tiling_config.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom dataclasses import dataclass\n\n@dataclass(frozen=True, slots=True)\nclass TilingConfig:\n    vae_decode_tiles_per_dim: int | None = ...\n    vae_decode_overlap: int = ...\n    vae_encode_tiled: bool = ...\n    vae_encode_tile_size: int = ...\n    vae_encode_tile_overlap: int = ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/vae/vae_tiler.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom typing import Callable\n\nclass VAETiler:\n    @staticmethod\n    def encode_image_tiled(\n        *,\n        image: mx.array,\n        encode_fn: Callable[[mx.array], mx.array],\n        latent_channels: int,\n        tile_size: tuple[int, int] = ...,\n        tile_overlap: tuple[int, int] = ...,\n        spatial_scale: int = ...,\n    ) -> mx.array: ...\n    @staticmethod\n    def decode_image_tiled(\n        *,\n        latent: mx.array,\n        decode_fn: Callable[[mx.array], mx.array],\n        tile_size: tuple[int, int] = ...,\n        tile_overlap: tuple[int, int] = ...,\n        spatial_scale: int = ...,\n    ) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/vae/vae_util.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\nfrom mflux.models.common.vae.tiling_config import TilingConfig\n\nclass VAEUtil:\n    @staticmethod\n    def encode(\n        vae: nn.Module, image: mx.array, tiling_config: TilingConfig | None = ...\n    ) -> mx.array: ...\n    @staticmethod\n    def decode(\n        vae: nn.Module, latent: mx.array, tiling_config: TilingConfig | None = ...\n    ) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/weights/__init__.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom mflux.models.common.weights.loading.loaded_weights import LoadedWeights, MetaData\nfrom mflux.models.common.weights.loading.weight_applier import WeightApplier\nfrom mflux.models.common.weights.loading.weight_definition import ComponentDefinition\nfrom mflux.models.common.weights.loading.weight_loader import WeightLoader\nfrom mflux.models.common.weights.saving.model_saver import ModelSaver\n\n__all__ = [\n    \"ComponentDefinition\",\n    \"LoadedWeights\",\n    \"MetaData\",\n    \"ModelSaver\",\n    \"WeightApplier\",\n    \"WeightLoader\",\n]\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/weights/loading/loaded_weights.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom dataclasses import dataclass\n\n@dataclass\nclass MetaData:\n    quantization_level: int | None = ...\n    mflux_version: str | None = ...\n\n@dataclass\nclass LoadedWeights:\n    components: dict[str, dict]\n    meta_data: MetaData\n    def __getattr__(self, name: str) -> dict | None: ...\n    def num_transformer_blocks(self, component_name: str = ...) -> int: ...\n    def num_single_transformer_blocks(self, component_name: str = ...) -> int: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/weights/loading/weight_applier.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.nn as nn\nfrom typing import TYPE_CHECKING\nfrom mflux.models.common.weights.loading.loaded_weights import LoadedWeights\nfrom mflux.models.common.weights.loading.weight_definition import (\n    ComponentDefinition,\n    WeightDefinitionType,\n)\n\nif TYPE_CHECKING: ...\n\nclass WeightApplier:\n    @staticmethod\n    def apply_and_quantize_single(\n        weights: LoadedWeights,\n        model: nn.Module,\n        component: ComponentDefinition,\n        quantize_arg: int | None,\n        quantization_predicate=...,\n    ) -> int | None: ...\n    @staticmethod\n    def apply_and_quantize(\n        weights: LoadedWeights,\n        models: dict[str, nn.Module],\n        quantize_arg: int | None,\n        weight_definition: WeightDefinitionType,\n    ) -> int | None: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/weights/loading/weight_definition.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom dataclasses import dataclass\nfrom typing import Callable, List, TYPE_CHECKING, TypeAlias\nfrom mflux.models.common.weights.mapping.weight_mapping import WeightTarget\nfrom mflux.models.common.tokenizer.tokenizer import BaseTokenizer\nfrom mflux.models.depth_pro.weights.depth_pro_weight_definition import (\n    DepthProWeightDefinition,\n)\nfrom mflux.models.fibo.weights.fibo_weight_definition import FIBOWeightDefinition\nfrom mflux.models.fibo_vlm.weights.fibo_vlm_weight_definition import (\n    FIBOVLMWeightDefinition,\n)\nfrom mflux.models.flux.weights.flux_weight_definition import FluxWeightDefinition\nfrom mflux.models.qwen.weights.qwen_weight_definition import QwenWeightDefinition\nfrom mflux.models.seedvr2.weights.seedvr2_weight_definition import (\n    SeedVR2WeightDefinition,\n)\nfrom mflux.models.z_image.weights.z_image_weight_definition import (\n    ZImageWeightDefinition,\n)\n\n\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\nif TYPE_CHECKING:\n    WeightDefinitionType: TypeAlias = type[\n        FluxWeightDefinition\n        | FIBOWeightDefinition\n        | FIBOVLMWeightDefinition\n        | QwenWeightDefinition\n        | ZImageWeightDefinition\n        | SeedVR2WeightDefinition\n        | DepthProWeightDefinition\n    ]\n\n@dataclass\nclass ComponentDefinition:\n    name: str\n    hf_subdir: str\n    mapping_getter: Callable[[], List[WeightTarget]] | None = ...\n    model_attr: str | None = ...\n    num_blocks: int | None = ...\n    num_layers: int | None = ...\n    loading_mode: str = ...\n    precision: mx.Dtype | None = ...\n    skip_quantization: bool = ...\n    bulk_transform: Callable[[mx.array], mx.array] | None = ...\n    weight_subkey: str | None = ...\n    download_url: str | None = ...\n    weight_prefix_filters: List[str] | None = ...\n    weight_files: List[str] | None = ...\n\n@dataclass\nclass TokenizerDefinition:\n    name: str\n    hf_subdir: str\n    tokenizer_class: str = ...\n    fallback_subdirs: List[str] | None = ...\n    download_patterns: List[str] | None = ...\n    encoder_class: type[BaseTokenizer] | None = ...\n    max_length: int = ...\n    padding: str = ...\n    template: str | None = ...\n    use_chat_template: bool = ...\n    chat_template_kwargs: dict | None = ...\n    add_special_tokens: bool = ...\n    processor_class: type | None = ...\n    image_token: str = ...\n    chat_template: str | None = ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/weights/loading/weight_loader.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import TYPE_CHECKING\nfrom mflux.models.common.weights.loading.loaded_weights import LoadedWeights\nfrom mflux.models.common.weights.loading.weight_definition import (\n    ComponentDefinition,\n    WeightDefinitionType,\n)\n\nif TYPE_CHECKING: ...\nlogger = ...\n\nclass WeightLoader:\n    @staticmethod\n    def load_single(\n        component: ComponentDefinition, repo_id: str, file_pattern: str = ...\n    ) -> LoadedWeights: ...\n    @staticmethod\n    def load(\n        weight_definition: WeightDefinitionType, model_path: str | None = ...\n    ) -> LoadedWeights: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/weights/mapping/weight_mapper.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom typing import Dict, List, Optional\nfrom mflux.models.common.weights.mapping.weight_mapping import WeightTarget\n\nclass WeightMapper:\n    @staticmethod\n    def apply_mapping(\n        hf_weights: Dict[str, mx.array],\n        mapping: List[WeightTarget],\n        num_blocks: Optional[int] = ...,\n        num_layers: Optional[int] = ...,\n    ) -> Dict: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/weights/mapping/weight_mapping.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom dataclasses import dataclass\nfrom typing import Callable, List, Optional, Protocol\n\n\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\n@dataclass\nclass WeightTarget:\n    to_pattern: str\n    from_pattern: List[str]\n    transform: Optional[Callable[[mx.array], mx.array]] = ...\n    required: bool = ...\n    max_blocks: Optional[int] = ...\n\nclass WeightMapping(Protocol):\n    @staticmethod\n    def get_mapping() -> List[WeightTarget]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/weights/mapping/weight_transforms.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\n\nclass WeightTransforms:\n    @staticmethod\n    def reshape_gamma_to_1d(tensor: mx.array) -> mx.array: ...\n    @staticmethod\n    def transpose_patch_embed(tensor: mx.array) -> mx.array: ...\n    @staticmethod\n    def transpose_conv3d_weight(tensor: mx.array) -> mx.array: ...\n    @staticmethod\n    def transpose_conv2d_weight(tensor: mx.array) -> mx.array: ...\n    @staticmethod\n    def transpose_conv_transpose2d_weight(tensor: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/common/weights/saving/model_saver.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import Any, TYPE_CHECKING\nfrom mflux.models.common.weights.loading.weight_definition import WeightDefinitionType\n\nif TYPE_CHECKING: ...\n\nclass ModelSaver:\n    @staticmethod\n    def save_model(\n        model: Any, bits: int, base_path: str, weight_definition: WeightDefinitionType\n    ) -> None: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/depth_pro/depth_pro_initializer.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom mflux.models.depth_pro.model.depth_pro_model import DepthProModel\n\nclass DepthProInitializer:\n    @staticmethod\n    def init(model: DepthProModel, quantize: int | None = ...) -> None: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/depth_pro/model/decoder/feature_fusion_block_2d.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nclass FeatureFusionBlock2d(nn.Module):\n    def __init__(self, num_features: int, deconv: bool = ...) -> None: ...\n    def __call__(self, x0: mx.array, x1: mx.array | None = ...) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/depth_pro/model/decoder/multires_conv_decoder.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nclass MultiresConvDecoder(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(\n        self,\n        x0_latent: mx.array,\n        x1_latent: mx.array,\n        x0_features: mx.array,\n        x1_features: mx.array,\n        x_global_features: mx.array,\n    ) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/depth_pro/model/decoder/residual_block.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nclass ResidualBlock(nn.Module):\n    def __init__(self, num_features: int) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/depth_pro/model/depth_pro.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom PIL import Image\n\n@dataclass\nclass DepthResult:\n    depth_image: Image.Image\n    depth_array: mx.array\n    min_depth: float\n    max_depth: float\n    ...\n\nclass DepthPro:\n    def __init__(self, quantize: int | None = ...) -> None: ...\n    def create_depth_map(self, image_path: str | Path) -> DepthResult: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/depth_pro/model/depth_pro_model.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nclass DepthProModel(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(\n        self, x0: mx.array, x1: mx.array, x2: mx.array\n    ) -> tuple[mx.array, mx.array]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/depth_pro/model/depth_pro_util.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nclass DepthProUtil:\n    @staticmethod\n    def split(x: mx.array, overlap_ratio: float = ...) -> mx.array: ...\n    @staticmethod\n    def interpolate(x: mx.array, size=..., scale_factor=...):  # -> array:\n        ...\n    @staticmethod\n    def apply_conv(x: mx.array, conv_module: nn.Module) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/depth_pro/model/dino_v2/attention.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass Attention(nn.Module):\n    def __init__(\n        self, dim: int = ..., head_dim: int = ..., num_heads: int = ...\n    ) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/depth_pro/model/dino_v2/dino_vision_transformer.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nclass DinoVisionTransformer(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, x: mx.array) -> tuple[mx.array, mx.array, mx.array]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/depth_pro/model/dino_v2/layer_scale.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nclass LayerScale(nn.Module):\n    def __init__(self, dims: int, init_values: float = ...) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/depth_pro/model/dino_v2/mlp.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nclass MLP(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/depth_pro/model/dino_v2/patch_embed.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nclass PatchEmbed(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/depth_pro/model/dino_v2/transformer_block.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nclass TransformerBlock(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/depth_pro/model/encoder/depth_pro_encoder.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nclass DepthProEncoder(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(\n        self, x0: mx.array, x1: mx.array, x2: mx.array\n    ) -> tuple[mx.array, mx.array, mx.array, mx.array, mx.array]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/depth_pro/model/encoder/upsample_block.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nclass UpSampleBlock(nn.Module):\n    def __init__(\n        self,\n        dim_in: int = ...,\n        dim_int: int = ...,\n        dim_out: int = ...,\n        upsample_layers: int = ...,\n    ) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/depth_pro/model/head/fov_head.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nclass FOVHead(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/depth_pro/weights/depth_pro_weight_definition.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import List\nfrom mflux.models.common.weights.loading.weight_definition import (\n    ComponentDefinition,\n    TokenizerDefinition,\n)\n\n\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nclass DepthProWeightDefinition:\n    @staticmethod\n    def get_components() -> List[ComponentDefinition]: ...\n    @staticmethod\n    def get_tokenizers() -> List[TokenizerDefinition]: ...\n    @staticmethod\n    def get_download_patterns() -> List[str]: ...\n    @staticmethod\n    def quantization_predicate(path: str, module) -> bool: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/depth_pro/weights/depth_pro_weight_mapping.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import List\nfrom mflux.models.common.weights.mapping.weight_mapping import (\n    WeightMapping,\n    WeightTarget,\n)\n\nclass DepthProWeightMapping(WeightMapping):\n    @staticmethod\n    def get_mapping() -> List[WeightTarget]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/fibo/latent_creator/fibo_latent_creator.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\n\nclass FiboLatentCreator:\n    @staticmethod\n    def create_noise(seed: int, height: int, width: int) -> mx.array: ...\n    @staticmethod\n    def pack_latents(latents: mx.array, height: int, width: int) -> mx.array: ...\n    @staticmethod\n    def unpack_latents(latents: mx.array, height: int, width: int) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/fibo/weights/fibo_weight_definition.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import List\nfrom mflux.models.common.weights.loading.weight_definition import (\n    ComponentDefinition,\n    TokenizerDefinition,\n)\n\n\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nclass FIBOWeightDefinition:\n    @staticmethod\n    def get_components() -> List[ComponentDefinition]: ...\n    @staticmethod\n    def get_tokenizers() -> List[TokenizerDefinition]: ...\n    @staticmethod\n    def get_download_patterns() -> List[str]: ...\n    @staticmethod\n    def quantization_predicate(path: str, module) -> bool: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/fibo/weights/fibo_weight_mapping.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import List\nfrom mflux.models.common.weights.mapping.weight_mapping import (\n    WeightMapping,\n    WeightTarget,\n)\n\nclass FIBOWeightMapping(WeightMapping):\n    @staticmethod\n    def get_transformer_mapping() -> List[WeightTarget]: ...\n    @staticmethod\n    def get_text_encoder_mapping() -> List[WeightTarget]: ...\n    @staticmethod\n    def get_vae_mapping() -> List[WeightTarget]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/fibo_vlm/tokenizer/qwen2vl_image_processor.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom mflux.models.qwen.tokenizer.qwen_image_processor import QwenImageProcessor\n\nclass Qwen2VLImageProcessor(QwenImageProcessor):\n    def __init__(self) -> None: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/fibo_vlm/tokenizer/qwen2vl_processor.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import Optional, Union\nfrom PIL import Image\n\nclass Qwen2VLProcessor:\n    def __init__(self, tokenizer) -> None: ...\n    def apply_chat_template(\n        self,\n        messages,\n        tokenize: bool = ...,\n        add_generation_prompt: bool = ...,\n        return_tensors: Optional[str] = ...,\n        return_dict: bool = ...,\n        **kwargs,\n    ):  # -> dict[Any, Any]:\n        ...\n    def __call__(\n        self,\n        text: Optional[Union[str, list[str]]] = ...,\n        images: Optional[Union[Image.Image, list[Image.Image]]] = ...,\n        padding: bool = ...,\n        return_tensors: Optional[str] = ...,\n        **kwargs,\n    ):  # -> dict[Any, Any]:\n        ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/fibo_vlm/weights/fibo_vlm_weight_definition.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import List\nfrom mflux.models.common.weights.loading.weight_definition import (\n    ComponentDefinition,\n    TokenizerDefinition,\n)\n\n\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\nQWEN2VL_CHAT_TEMPLATE = ...\n\nclass FIBOVLMWeightDefinition:\n    @staticmethod\n    def get_components() -> List[ComponentDefinition]: ...\n    @staticmethod\n    def get_tokenizers() -> List[TokenizerDefinition]: ...\n    @staticmethod\n    def get_download_patterns() -> List[str]: ...\n    @staticmethod\n    def quantization_predicate(path: str, module) -> bool: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/fibo_vlm/weights/fibo_vlm_weight_mapping.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import List\nfrom mflux.models.common.weights.mapping.weight_mapping import (\n    WeightMapping,\n    WeightTarget,\n)\n\nclass FIBOVLMWeightMapping(WeightMapping):\n    @staticmethod\n    def get_vlm_decoder_mapping(num_layers: int = ...) -> List[WeightTarget]: ...\n    @staticmethod\n    def get_vlm_visual_mapping(depth: int = ...) -> List[WeightTarget]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/__init__.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/cli/__init__.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/flux_initializer.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom mflux.models.common.config import ModelConfig\n\nclass FluxInitializer:\n    @staticmethod\n    def init(\n        model,\n        model_config: ModelConfig,\n        quantize: int | None,\n        model_path: str | None = ...,\n        lora_paths: list[str] | None = ...,\n        lora_scales: list[float] | None = ...,\n        custom_transformer=...,\n    ) -> None: ...\n    @staticmethod\n    def init_depth(\n        model,\n        model_config: ModelConfig,\n        quantize: int | None,\n        model_path: str | None = ...,\n        lora_paths: list[str] | None = ...,\n        lora_scales: list[float] | None = ...,\n    ) -> None: ...\n    @staticmethod\n    def init_redux(\n        model,\n        model_config: ModelConfig,\n        quantize: int | None,\n        model_path: str | None = ...,\n        lora_paths: list[str] | None = ...,\n        lora_scales: list[float] | None = ...,\n    ) -> None: ...\n    @staticmethod\n    def init_controlnet(\n        model,\n        model_config: ModelConfig,\n        quantize: int | None,\n        model_path: str | None = ...,\n        lora_paths: list[str] | None = ...,\n        lora_scales: list[float] | None = ...,\n    ) -> None: ...\n    @staticmethod\n    def init_concept(\n        model,\n        model_config: ModelConfig,\n        quantize: int | None,\n        model_path: str | None = ...,\n        lora_paths: list[str] | None = ...,\n        lora_scales: list[float] | None = ...,\n    ) -> None: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/latent_creator/__init__.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\n\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/latent_creator/flux_latent_creator.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\n\n\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nclass FluxLatentCreator:\n    @staticmethod\n    def create_noise(seed: int, height: int, width: int) -> mx.array: ...\n    @staticmethod\n    def pack_latents(\n        latents: mx.array, height: int, width: int, num_channels_latents: int = ...\n    ) -> mx.array: ...\n    @staticmethod\n    def unpack_latents(latents: mx.array, height: int, width: int) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/__init__.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\n\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_text_encoder/clip_encoder/clip_embeddings.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass CLIPEmbeddings(nn.Module):\n    def __init__(self, dims: int) -> None: ...\n    def __call__(self, tokens: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_text_encoder/clip_encoder/clip_encoder.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\n\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nclass CLIPEncoder(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, tokens: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_text_encoder/clip_encoder/clip_encoder_layer.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass CLIPEncoderLayer(nn.Module):\n    def __init__(self, layer: int) -> None: ...\n    def __call__(\n        self, hidden_states: mx.array, causal_attention_mask: mx.array\n    ) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_text_encoder/clip_encoder/clip_mlp.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass CLIPMLP(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, hidden_states: mx.array) -> mx.array: ...\n    @staticmethod\n    def quick_gelu(input_array: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_text_encoder/clip_encoder/clip_sdpa_attention.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass CLIPSdpaAttention(nn.Module):\n    head_dimension = ...\n    batch_size = ...\n    num_heads = ...\n    def __init__(self) -> None: ...\n    def __call__(\n        self, hidden_states: mx.array, causal_attention_mask: mx.array\n    ) -> mx.array: ...\n    @staticmethod\n    def reshape_and_transpose(x, batch_size, num_heads, head_dim):  # -> array:\n        ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_text_encoder/clip_encoder/clip_text_model.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass CLIPTextModel(nn.Module):\n    def __init__(self, dims: int, num_encoder_layers: int) -> None: ...\n    def __call__(self, tokens: mx.array) -> tuple[mx.array, mx.array]: ...\n    @staticmethod\n    def create_causal_attention_mask(input_shape: tuple) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_text_encoder/clip_encoder/encoder_clip.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass EncoderCLIP(nn.Module):\n    def __init__(self, num_encoder_layers: int) -> None: ...\n    def __call__(\n        self, tokens: mx.array, causal_attention_mask: mx.array\n    ) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_text_encoder/prompt_encoder.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mflux.models.common.tokenizer import Tokenizer\nfrom mflux.models.flux.model.flux_text_encoder.clip_encoder.clip_encoder import (\n    CLIPEncoder,\n)\nfrom mflux.models.flux.model.flux_text_encoder.t5_encoder.t5_encoder import T5Encoder\n\n\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nclass PromptEncoder:\n    @staticmethod\n    def encode_prompt(\n        prompt: str,\n        prompt_cache: dict[str, tuple[mx.array, mx.array]],\n        t5_tokenizer: Tokenizer,\n        clip_tokenizer: Tokenizer,\n        t5_text_encoder: T5Encoder,\n        clip_text_encoder: CLIPEncoder,\n    ) -> tuple[mx.array, mx.array]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_text_encoder/t5_encoder/t5_attention.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass T5Attention(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, hidden_states: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_text_encoder/t5_encoder/t5_block.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass T5Block(nn.Module):\n    def __init__(self, layer: int) -> None: ...\n    def __call__(self, hidden_states: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_text_encoder/t5_encoder/t5_dense_relu_dense.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass T5DenseReluDense(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, hidden_states: mx.array) -> mx.array: ...\n    @staticmethod\n    def new_gelu(input_array: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_text_encoder/t5_encoder/t5_encoder.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\n\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nclass T5Encoder(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, tokens: mx.array): ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_text_encoder/t5_encoder/t5_feed_forward.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass T5FeedForward(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, hidden_states: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_text_encoder/t5_encoder/t5_layer_norm.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass T5LayerNorm(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, hidden_states: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_text_encoder/t5_encoder/t5_self_attention.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass T5SelfAttention(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, hidden_states: mx.array) -> mx.array: ...\n    @staticmethod\n    def shape(states):  # -> array:\n        ...\n    @staticmethod\n    def un_shape(states):  # -> array:\n        ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_transformer/ada_layer_norm_continuous.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass AdaLayerNormContinuous(nn.Module):\n    def __init__(self, embedding_dim: int, conditioning_embedding_dim: int) -> None: ...\n    def __call__(self, x: mx.array, text_embeddings: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_transformer/ada_layer_norm_zero.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass AdaLayerNormZero(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(\n        self, hidden_states: mx.array, text_embeddings: mx.array\n    ) -> tuple[mx.array, mx.array, mx.array, mx.array, mx.array]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_transformer/ada_layer_norm_zero_single.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass AdaLayerNormZeroSingle(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(\n        self, hidden_states: mx.array, text_embeddings: mx.array\n    ) -> tuple[mx.array, mx.array]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_transformer/common/attention_utils.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass AttentionUtils:\n    @staticmethod\n    def process_qkv(\n        hidden_states: mx.array,\n        to_q: nn.Linear,\n        to_k: nn.Linear,\n        to_v: nn.Linear,\n        norm_q: nn.RMSNorm,\n        norm_k: nn.RMSNorm,\n        num_heads: int,\n        head_dim: int,\n    ) -> tuple[mx.array, mx.array, mx.array]: ...\n    @staticmethod\n    def compute_attention(\n        query: mx.array,\n        key: mx.array,\n        value: mx.array,\n        batch_size: int,\n        num_heads: int,\n        head_dim: int,\n        mask: mx.array | None = ...,\n    ) -> mx.array: ...\n    @staticmethod\n    def convert_key_padding_mask_to_additive_mask(\n        mask: mx.array | None, joint_seq_len: int, txt_seq_len: int\n    ) -> mx.array | None: ...\n    @staticmethod\n    def apply_rope(\n        xq: mx.array, xk: mx.array, freqs_cis: mx.array\n    ) -> tuple[mx.array, mx.array]: ...\n    @staticmethod\n    def apply_rope_bshd(\n        xq: mx.array, xk: mx.array, cos: mx.array, sin: mx.array\n    ) -> tuple[mx.array, mx.array]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_transformer/embed_nd.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass EmbedND(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, ids: mx.array) -> mx.array: ...\n    @staticmethod\n    def rope(pos: mx.array, dim: int, theta: float) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_transformer/feed_forward.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass FeedForward(nn.Module):\n    def __init__(self, activation_function) -> None: ...\n    def __call__(self, hidden_states: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_transformer/guidance_embedder.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass GuidanceEmbedder(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, sample: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_transformer/joint_attention.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\nfrom typing import Any\n\nclass JointAttention(nn.Module):\n    num_heads: int\n    head_dimension: int\n    to_q: nn.Linear\n    to_k: nn.Linear\n    to_v: nn.Linear\n    norm_q: nn.RMSNorm\n    norm_k: nn.RMSNorm\n    add_q_proj: nn.Linear\n    add_k_proj: nn.Linear\n    add_v_proj: nn.Linear\n    norm_added_q: nn.RMSNorm\n    norm_added_k: nn.RMSNorm\n    to_out: list[Any]\n    to_add_out: nn.Linear\n\n    def __init__(self) -> None: ...\n    def __call__(\n        self,\n        hidden_states: mx.array,\n        encoder_hidden_states: mx.array,\n        image_rotary_emb: mx.array,\n    ) -> tuple[mx.array, mx.array]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_transformer/joint_transformer_block.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\nfrom typing import Any\nfrom mflux.models.flux.model.flux_transformer.joint_attention import JointAttention\nfrom mflux.models.flux.model.flux_transformer.ada_layer_norm_zero import (\n    AdaLayerNormZero,\n)\n\nclass JointTransformerBlock(nn.Module):\n    attn: JointAttention\n    norm1: AdaLayerNormZero\n    norm1_context: AdaLayerNormZero\n    norm2: nn.Module\n    norm2_context: nn.Module\n    ff: nn.Module\n    ff_context: nn.Module\n\n    def __init__(self, layer: Any) -> None: ...\n    def __call__(\n        self,\n        hidden_states: mx.array,\n        encoder_hidden_states: mx.array,\n        text_embeddings: mx.array,\n        rotary_embeddings: mx.array,\n    ) -> tuple[mx.array, mx.array]: ...\n    @staticmethod\n    def apply_norm_and_feed_forward(\n        hidden_states: mx.array,\n        attn_output: mx.array,\n        gate_mlp: mx.array,\n        gate_msa: mx.array,\n        scale_mlp: mx.array,\n        shift_mlp: mx.array,\n        norm_layer: nn.Module,\n        ff_layer: nn.Module,\n    ) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_transformer/single_block_attention.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass SingleBlockAttention(nn.Module):\n    num_heads: int\n    head_dimension: int\n    to_q: nn.Linear\n    to_k: nn.Linear\n    to_v: nn.Linear\n    norm_q: nn.RMSNorm\n    norm_k: nn.RMSNorm\n\n    def __init__(self) -> None: ...\n    def __call__(\n        self, hidden_states: mx.array, image_rotary_emb: mx.array\n    ) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_transformer/single_transformer_block.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\nfrom typing import Any\nfrom mflux.models.flux.model.flux_transformer.single_block_attention import (\n    SingleBlockAttention,\n)\nfrom mflux.models.flux.model.flux_transformer.ada_layer_norm_zero_single import (\n    AdaLayerNormZeroSingle,\n)\n\nclass SingleTransformerBlock(nn.Module):\n    attn: SingleBlockAttention\n    norm: AdaLayerNormZeroSingle\n\n    def __init__(self, layer: Any) -> None: ...\n    def __call__(\n        self,\n        hidden_states: mx.array,\n        text_embeddings: mx.array,\n        rotary_embeddings: mx.array,\n    ) -> tuple[mx.array, mx.array]: ...\n    def _apply_feed_forward_and_projection(\n        self, norm_hidden_states: mx.array, attn_output: mx.array, gate: mx.array\n    ) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_transformer/text_embedder.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass TextEmbedder(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, caption: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_transformer/time_text_embed.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\nfrom mflux.models.common.config import ModelConfig\n\nclass TimeTextEmbed(nn.Module):\n    def __init__(self, model_config: ModelConfig) -> None: ...\n    def __call__(\n        self, time_step: mx.array, pooled_projection: mx.array, guidance: mx.array\n    ) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_transformer/timestep_embedder.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass TimestepEmbedder(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, sample: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_transformer/transformer.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\nfrom mflux.models.common.config.config import Config\nfrom mflux.models.common.config.model_config import ModelConfig\nfrom mflux.models.flux.model.flux_transformer.embed_nd import EmbedND\nfrom mflux.models.flux.model.flux_transformer.time_text_embed import TimeTextEmbed\nfrom mflux.models.flux.model.flux_transformer.joint_transformer_block import (\n    JointTransformerBlock,\n)\nfrom mflux.models.flux.model.flux_transformer.single_transformer_block import (\n    SingleTransformerBlock,\n)\n\nclass Transformer(nn.Module):\n    transformer_blocks: list[JointTransformerBlock]\n    single_transformer_blocks: list[SingleTransformerBlock]\n    x_embedder: nn.Linear\n    pos_embed: EmbedND\n    time_text_embed: TimeTextEmbed\n    norm_out: nn.LayerNorm\n    proj_out: nn.Linear\n    context_embedder: nn.Linear\n\n    def __init__(\n        self,\n        model_config: ModelConfig,\n        num_transformer_blocks: int = ...,\n        num_single_transformer_blocks: int = ...,\n    ) -> None: ...\n    def __call__(\n        self,\n        t: int,\n        config: Config,\n        hidden_states: mx.array,\n        prompt_embeds: mx.array,\n        pooled_prompt_embeds: mx.array,\n        controlnet_block_samples: list[mx.array] | None = ...,\n        controlnet_single_block_samples: list[mx.array] | None = ...,\n        kontext_image_ids: mx.array | None = ...,\n    ) -> mx.array: ...\n    @staticmethod\n    def compute_rotary_embeddings(\n        prompt_embeds: mx.array,\n        pos_embed: EmbedND,\n        config: Config,\n        kontext_image_ids: mx.array | None = ...,\n    ) -> mx.array: ...\n    @staticmethod\n    def compute_text_embeddings(\n        t: int,\n        pooled_prompt_embeds: mx.array,\n        time_text_embed: TimeTextEmbed,\n        config: Config,\n    ) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_vae/common/attention.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass Attention(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, input_array: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_vae/common/resnet_block_2d.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass ResnetBlock2D(nn.Module):\n    def __init__(\n        self,\n        norm1: int,\n        conv1_in: int,\n        conv1_out: int,\n        norm2: int,\n        conv2_in: int,\n        conv2_out: int,\n        conv_shortcut_in: int | None = ...,\n        conv_shortcut_out: int | None = ...,\n        is_conv_shortcut: bool = ...,\n    ) -> None: ...\n    def __call__(self, input_array: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_vae/common/unet_mid_block.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass UnetMidBlock(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, input_array: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_vae/decoder/conv_in.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nclass ConvIn(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, input_array: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_vae/decoder/conv_norm_out.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nclass ConvNormOut(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, input_array: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_vae/decoder/conv_out.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nclass ConvOut(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, input_array: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_vae/decoder/decoder.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass Decoder(nn.Module):\n    def __init__(\n        self, enable_tiling: bool = ..., split_direction: str = ...\n    ) -> None: ...\n    def __call__(self, latents: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_vae/decoder/up_block_1_or_2.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass UpBlock1Or2(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, input_array: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_vae/decoder/up_block_3.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass UpBlock3(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, input_array: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_vae/decoder/up_block_4.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass UpBlock4(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, input_array: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_vae/decoder/up_sampler.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass UpSampler(nn.Module):\n    def __init__(self, conv_in: int, conv_out: int) -> None: ...\n    def __call__(self, input_array: mx.array) -> mx.array: ...\n    @staticmethod\n    def up_sample_nearest(x: mx.array, scale: int = ...):  # -> array:\n        ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_vae/encoder/conv_in.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nclass ConvIn(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, input_array: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_vae/encoder/conv_norm_out.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nclass ConvNormOut(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, input_array: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_vae/encoder/conv_out.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nclass ConvOut(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, input_array: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_vae/encoder/down_block_1.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass DownBlock1(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, input_array: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_vae/encoder/down_block_2.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass DownBlock2(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, input_array: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_vae/encoder/down_block_3.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass DownBlock3(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, input_array: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_vae/encoder/down_block_4.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass DownBlock4(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, input_array: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_vae/encoder/down_sampler.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass DownSampler(nn.Module):\n    def __init__(self, conv_in: int, conv_out: int) -> None: ...\n    def __call__(self, input_array: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_vae/encoder/encoder.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass Encoder(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, latents: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/flux_vae/vae.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass VAE(nn.Module):\n    scaling_factor: int = ...\n    shift_factor: int = ...\n    spatial_scale = ...\n    latent_channels = ...\n    def __init__(self) -> None: ...\n    def decode(self, latents: mx.array) -> mx.array: ...\n    def encode(self, image: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/redux_encoder/redux_encoder.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass ReduxEncoder(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/siglip_vision_transformer/siglip_encoder.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass SiglipEncoder(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, inputs_embeds: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/siglip_vision_transformer/siglip_encoder_layer.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass SiglipEncoderLayer(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, hidden_states: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/siglip_vision_transformer/siglip_mlp.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass SiglipMLP(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, hidden_states: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/siglip_vision_transformer/siglip_multi_head_attention_pooling_head.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass SiglipMultiHeadAttentionPoolingHead(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, hidden_states: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/siglip_vision_transformer/siglip_sdpa_attention.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass SiglipSdpaAttention(nn.Module):\n    head_dimension = ...\n    batch_size = ...\n    num_heads = ...\n    def __init__(self) -> None: ...\n    def __call__(self, hidden_states: mx.array) -> mx.array: ...\n    @staticmethod\n    def reshape_and_transpose(x, batch_size, num_heads, head_dim):  # -> array:\n        ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/siglip_vision_transformer/siglip_vision_embeddings.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass SiglipVisionEmbeddings(nn.Module):\n    embed_dim = ...\n    image_size = ...\n    patch_size = ...\n    def __init__(self) -> None: ...\n    def __call__(self, pixel_values: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/model/siglip_vision_transformer/siglip_vision_transformer.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass SiglipVisionTransformer(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, pixel_values: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/variants/__init__.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\n\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/variants/concept_attention/attention_data.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nimport PIL.Image\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import List\nfrom mflux.models.flux.variants.concept_attention.joint_transformer_block_concept import (\n    LayerAttentionData,\n)\n\n@dataclass\nclass TimestepAttentionData:\n    t: int\n    attention_information: List[LayerAttentionData]\n    def stack_img_attentions(self) -> mx.array: ...\n    def stack_concept_attentions(self) -> mx.array: ...\n\nclass GenerationAttentionData:\n    def __init__(self) -> None: ...\n    def append(self, timestep_attention: TimestepAttentionData):  # -> None:\n        ...\n    def stack_all_img_attentions(self) -> mx.array: ...\n    def stack_all_concept_attentions(self) -> mx.array: ...\n\n@dataclass\nclass ConceptHeatmap:\n    concept: str\n    image: PIL.Image.Image\n    layer_indices: List[int]\n    timesteps: List[int]\n    height: int\n    width: int\n    def save(\n        self, path: str | Path, export_json_metadata: bool = ..., overwrite: bool = ...\n    ) -> None: ...\n    def get_metadata(self) -> dict: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/variants/concept_attention/joint_attention_concept.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass JointAttentionConcept(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(\n        self,\n        hidden_states: mx.array,\n        encoder_hidden_states: mx.array,\n        encoder_hidden_states_concept: mx.array,\n        image_rotary_emb: mx.array,\n        image_rotary_emb_concept: mx.array,\n    ) -> tuple[mx.array, mx.array, mx.array, mx.array, mx.array]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/variants/concept_attention/joint_transformer_block_concept.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom dataclasses import dataclass\nfrom mlx import nn\n\n@dataclass\nclass LayerAttentionData:\n    layer: int\n    img_attention: mx.array\n    concept_attention: mx.array\n    ...\n\nclass JointTransformerBlockConcept(nn.Module):\n    def __init__(self, layer) -> None: ...\n    def __call__(\n        self,\n        layer_idx: int,\n        hidden_states: mx.array,\n        encoder_hidden_states: mx.array,\n        encoder_hidden_states_concept: mx.array,\n        text_embeddings: mx.array,\n        text_embeddings_concept: mx.array,\n        rotary_embeddings: mx.array,\n        rotary_embeddings_concept: mx.array,\n    ) -> tuple[mx.array, mx.array, mx.array, LayerAttentionData]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/variants/concept_attention/transformer_concept.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\nfrom mflux.models.common.config.config import Config\nfrom mflux.models.common.config.model_config import ModelConfig\nfrom mflux.models.flux.variants.concept_attention.attention_data import (\n    TimestepAttentionData,\n)\n\nclass TransformerConcept(nn.Module):\n    def __init__(\n        self,\n        model_config: ModelConfig,\n        num_transformer_blocks: int = ...,\n        num_single_transformer_blocks: int = ...,\n    ) -> None: ...\n    def __call__(\n        self,\n        t: int,\n        config: Config,\n        hidden_states: mx.array,\n        prompt_embeds: mx.array,\n        prompt_embeds_concept: mx.array,\n        pooled_prompt_embeds: mx.array,\n        pooled_prompt_embeds_concept: mx.array,\n    ) -> tuple[mx.array, TimestepAttentionData]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/variants/controlnet/transformer_controlnet.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\nfrom mflux.models.common.config.config import Config\nfrom mflux.models.common.config.model_config import ModelConfig\n\nclass TransformerControlnet(nn.Module):\n    def __init__(\n        self,\n        model_config: ModelConfig,\n        num_transformer_blocks: int = ...,\n        num_single_transformer_blocks: int = ...,\n    ) -> None: ...\n    def __call__(\n        self,\n        t: int,\n        config: Config,\n        hidden_states: mx.array,\n        prompt_embeds: mx.array,\n        pooled_prompt_embeds: mx.array,\n        controlnet_condition: mx.array,\n    ) -> tuple[list[mx.array], list[mx.array]]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/variants/kontext/__init__.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom mflux.models.flux.variants.kontext.flux_kontext import Flux1Kontext\n\n__all__ = [\"Flux1Kontext\"]\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/variants/kontext/flux_kontext.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom pathlib import Path\nfrom typing import Any\n\nfrom mlx import nn\n\nfrom mflux.models.common.config.model_config import ModelConfig\nfrom mflux.models.flux.model.flux_text_encoder.clip_encoder.clip_encoder import (\n    CLIPEncoder,\n)\nfrom mflux.models.flux.model.flux_text_encoder.t5_encoder.t5_encoder import T5Encoder\nfrom mflux.models.flux.model.flux_transformer.transformer import Transformer\nfrom mflux.models.flux.model.flux_vae.vae import VAE\nfrom mflux.utils.generated_image import GeneratedImage\n\nclass Flux1Kontext(nn.Module):\n    vae: VAE\n    transformer: Transformer\n    t5_text_encoder: T5Encoder\n    clip_text_encoder: CLIPEncoder\n    bits: int | None\n    lora_paths: list[str] | None\n    lora_scales: list[float] | None\n    prompt_cache: dict[str, Any]\n    tokenizers: dict[str, Any]\n\n    def __init__(\n        self,\n        quantize: int | None = ...,\n        model_path: str | None = ...,\n        lora_paths: list[str] | None = ...,\n        lora_scales: list[float] | None = ...,\n        model_config: ModelConfig = ...,\n    ) -> None: ...\n    def generate_image(\n        self,\n        seed: int,\n        prompt: str,\n        num_inference_steps: int = ...,\n        height: int = ...,\n        width: int = ...,\n        guidance: float = ...,\n        image_path: Path | str | None = ...,\n        image_strength: float | None = ...,\n        scheduler: str = ...,\n    ) -> GeneratedImage: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/variants/kontext/kontext_util.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\n\nfrom mflux.models.flux.model.flux_vae.vae import VAE\n\nclass KontextUtil:\n    @staticmethod\n    def create_image_conditioning_latents(\n        vae: VAE,\n        height: int,\n        width: int,\n        image_path: str,\n    ) -> tuple[mx.array, mx.array]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/variants/txt2img/flux.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom pathlib import Path\nfrom mlx import nn\nfrom typing import Any\nfrom mflux.models.common.config.model_config import ModelConfig\nfrom mflux.models.flux.model.flux_text_encoder.clip_encoder.clip_encoder import (\n    CLIPEncoder,\n)\nfrom mflux.models.flux.model.flux_text_encoder.t5_encoder.t5_encoder import T5Encoder\nfrom mflux.models.flux.model.flux_transformer.transformer import Transformer\nfrom mflux.models.flux.model.flux_vae.vae import VAE\nfrom mflux.utils.generated_image import GeneratedImage\n\nclass Flux1(nn.Module):\n    vae: VAE\n    transformer: Transformer\n    t5_text_encoder: T5Encoder\n    clip_text_encoder: CLIPEncoder\n    bits: int | None\n    lora_paths: list[str] | None\n    lora_scales: list[float] | None\n    prompt_cache: dict[str, Any]\n    tokenizers: dict[str, Any]\n\n    def __init__(\n        self,\n        quantize: int | None = ...,\n        model_path: str | None = ...,\n        lora_paths: list[str] | None = ...,\n        lora_scales: list[float] | None = ...,\n        model_config: ModelConfig = ...,\n    ) -> None: ...\n    def generate_image(\n        self,\n        seed: int,\n        prompt: str,\n        num_inference_steps: int = ...,\n        height: int = ...,\n        width: int = ...,\n        guidance: float = ...,\n        image_path: Path | str | None = ...,\n        image_strength: float | None = ...,\n        scheduler: str = ...,\n        negative_prompt: str | None = ...,\n    ) -> GeneratedImage: ...\n    @staticmethod\n    def from_name(model_name: str, quantize: int | None = ...) -> Flux1: ...\n    def save_model(self, base_path: str) -> None: ...\n    def freeze(self, **kwargs):  # -> None:\n        ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/weights/__init__.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom mflux.models.flux.weights.flux_weight_definition import FluxWeightDefinition\nfrom mflux.models.flux.weights.flux_weight_mapping import FluxWeightMapping\n\n__all__ = [\"FluxWeightDefinition\", \"FluxWeightMapping\"]\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/weights/flux_lora_mapping.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom mflux.models.common.lora.mapping.lora_mapping import LoRAMapping, LoRATarget\n\nclass FluxLoRAMapping(LoRAMapping):\n    @staticmethod\n    def get_mapping() -> list[LoRATarget]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/weights/flux_weight_definition.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import List\nfrom mflux.models.common.weights.loading.weight_definition import (\n    ComponentDefinition,\n    TokenizerDefinition,\n)\n\n\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nclass FluxWeightDefinition:\n    @staticmethod\n    def get_components() -> List[ComponentDefinition]: ...\n    @staticmethod\n    def get_tokenizers() -> List[TokenizerDefinition]: ...\n    @staticmethod\n    def get_download_patterns() -> List[str]: ...\n    @staticmethod\n    def quantization_predicate(path: str, module) -> bool: ...\n\nclass FluxControlnetWeightDefinition:\n    @staticmethod\n    def get_controlnet_component() -> ComponentDefinition: ...\n    @staticmethod\n    def get_components() -> List[ComponentDefinition]: ...\n    @staticmethod\n    def get_tokenizers() -> List[TokenizerDefinition]: ...\n    @staticmethod\n    def get_download_patterns() -> List[str]: ...\n    @staticmethod\n    def quantization_predicate(path: str, module) -> bool: ...\n\nclass FluxReduxWeightDefinition:\n    @staticmethod\n    def get_components() -> List[ComponentDefinition]: ...\n    @staticmethod\n    def get_tokenizers() -> List[TokenizerDefinition]: ...\n    @staticmethod\n    def get_download_patterns() -> List[str]: ...\n    @staticmethod\n    def quantization_predicate(path: str, module) -> bool: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/flux/weights/flux_weight_mapping.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import List\nfrom mflux.models.common.weights.mapping.weight_mapping import (\n    WeightMapping,\n    WeightTarget,\n)\n\nclass FluxWeightMapping(WeightMapping):\n    @staticmethod\n    def get_transformer_mapping() -> List[WeightTarget]: ...\n    @staticmethod\n    def get_controlnet_transformer_mapping() -> List[WeightTarget]: ...\n    @staticmethod\n    def get_vae_mapping() -> List[WeightTarget]: ...\n    @staticmethod\n    def get_t5_encoder_mapping() -> List[WeightTarget]: ...\n    @staticmethod\n    def get_clip_encoder_mapping() -> List[WeightTarget]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/__init__.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/cli/__init__.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/latent_creator/__init__.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\n\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/latent_creator/qwen_latent_creator.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\n\n\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nclass QwenLatentCreator:\n    @staticmethod\n    def create_noise(seed: int, height: int, width: int) -> mx.array: ...\n    @staticmethod\n    def pack_latents(\n        latents: mx.array, height: int, width: int, num_channels_latents: int = ...\n    ) -> mx.array: ...\n    @staticmethod\n    def unpack_latents(latents: mx.array, height: int, width: int) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/__init__.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\n\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/qwen_text_encoder/qwen_attention.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass QwenAttention(nn.Module):\n    def __init__(\n        self,\n        hidden_size: int,\n        num_attention_heads: int,\n        num_key_value_heads: int = ...,\n        max_position_embeddings: int = ...,\n        rope_theta: float = ...,\n        rope_scaling: dict = ...,\n    ) -> None: ...\n    def __call__(\n        self,\n        hidden_states: mx.array,\n        attention_mask: mx.array | None = ...,\n        position_embeddings: tuple[mx.array, mx.array] | None = ...,\n    ) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/qwen_text_encoder/qwen_encoder.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass QwenEncoder(nn.Module):\n    def __init__(\n        self,\n        vocab_size: int = ...,\n        hidden_size: int = ...,\n        num_hidden_layers: int = ...,\n        max_position_embeddings: int = ...,\n        rope_theta: float = ...,\n    ) -> None: ...\n    def get_image_features(\n        self, pixel_values: mx.array, image_grid_thw: mx.array\n    ) -> mx.array: ...\n    def __call__(\n        self,\n        input_ids: mx.array,\n        attention_mask: mx.array,\n        pixel_values: mx.array | None = ...,\n        image_grid_thw: mx.array | None = ...,\n    ) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/qwen_text_encoder/qwen_encoder_layer.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass QwenEncoderLayer(nn.Module):\n    def __init__(\n        self,\n        hidden_size: int = ...,\n        num_attention_heads: int = ...,\n        num_key_value_heads: int = ...,\n        intermediate_size: int = ...,\n        rms_norm_eps: float = ...,\n        max_position_embeddings: int = ...,\n        rope_theta: float = ...,\n    ) -> None: ...\n    def __call__(\n        self,\n        hidden_states: mx.array,\n        attention_mask: mx.array | None = ...,\n        position_embeddings: tuple[mx.array, mx.array] | None = ...,\n    ) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/qwen_text_encoder/qwen_mlp.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass QwenMLP(nn.Module):\n    def __init__(self, hidden_size: int, intermediate_size: int) -> None: ...\n    def __call__(self, hidden_states: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/qwen_text_encoder/qwen_patch_merger.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass PatchMerger(nn.Module):\n    def __init__(\n        self, context_dim: int, hidden_size: int, spatial_merge_size: int = ...\n    ) -> None: ...\n    def __call__(self, x: mx.array, grid_thw: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/qwen_text_encoder/qwen_prompt_encoder.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mflux.models.common.tokenizer import Tokenizer\nfrom mflux.models.qwen.model.qwen_text_encoder.qwen_text_encoder import QwenTextEncoder\n\n\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nclass QwenPromptEncoder:\n    @staticmethod\n    def encode_prompt(\n        prompt: str,\n        negative_prompt: str,\n        prompt_cache: dict[str, tuple[mx.array, mx.array, mx.array, mx.array]],\n        qwen_tokenizer: Tokenizer,\n        qwen_text_encoder: QwenTextEncoder,\n    ) -> tuple[mx.array, mx.array, mx.array, mx.array]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/qwen_text_encoder/qwen_rms_norm.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass QwenRMSNorm(nn.Module):\n    def __init__(self, hidden_size: int, eps: float = ...) -> None: ...\n    def __call__(self, hidden_states: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/qwen_text_encoder/qwen_rope.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass QwenRotaryEmbedding(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        max_position_embeddings: int = ...,\n        base: float = ...,\n        device: str = ...,\n        scaling_factor: float = ...,\n        rope_type: str = ...,\n        config=...,\n    ) -> None: ...\n    def __call__(\n        self, x: mx.array, position_ids: mx.array\n    ) -> tuple[mx.array, mx.array]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/qwen_text_encoder/qwen_text_encoder.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\n\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nclass QwenTextEncoder(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(\n        self, input_ids: mx.array, attention_mask: mx.array\n    ) -> tuple[mx.array, mx.array]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/qwen_text_encoder/qwen_vision_attention.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass VisionAttention(nn.Module):\n    def __init__(self, embed_dim: int = ..., num_heads: int = ...) -> None: ...\n    def __call__(\n        self, x: mx.array, position_embeddings=..., cu_seqlens=...\n    ) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/qwen_text_encoder/qwen_vision_block.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass VisionBlock(nn.Module):\n    def __init__(\n        self, embed_dim: int = ..., num_heads: int = ..., mlp_ratio: float = ...\n    ) -> None: ...\n    def __call__(\n        self, x: mx.array, position_embeddings=..., cu_seqlens=...\n    ) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/qwen_text_encoder/qwen_vision_language_encoder.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass QwenVisionLanguageEncoder(nn.Module):\n    def __init__(self, encoder=...) -> None: ...\n    def __call__(\n        self,\n        input_ids: mx.array,\n        attention_mask: mx.array | None = ...,\n        pixel_values: mx.array | None = ...,\n        image_grid_thw: mx.array | None = ...,\n    ) -> tuple[mx.array, mx.array]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/qwen_text_encoder/qwen_vision_mlp.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass VisionMLP(nn.Module):\n    def __init__(self, dim: int, hidden_dim: int) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/qwen_text_encoder/qwen_vision_patch_embed.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass VisionPatchEmbed(nn.Module):\n    def __init__(\n        self,\n        patch_size: int = ...,\n        temporal_patch_size: int = ...,\n        in_channels: int = ...,\n        embed_dim: int = ...,\n    ) -> None: ...\n    def __call__(self, hidden_states: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/qwen_text_encoder/qwen_vision_rotary_embedding.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass VisionRotaryEmbedding(nn.Module):\n    def __init__(self, dim: int, theta: float = ...) -> None: ...\n    def __call__(self, max_grid_size: int) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/qwen_text_encoder/qwen_vision_transformer.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass VisionTransformer(nn.Module):\n    def __init__(\n        self,\n        patch_size: int = ...,\n        temporal_patch_size: int = ...,\n        in_channels: int = ...,\n        embed_dim: int = ...,\n        depth: int = ...,\n        num_heads: int = ...,\n        mlp_ratio: float = ...,\n        hidden_size: int = ...,\n        spatial_merge_size: int = ...,\n        window_size: int = ...,\n        fullatt_block_indexes: list = ...,\n    ) -> None: ...\n    def get_window_index(self, grid_thw: mx.array):  # -> tuple[array, array]:\n        ...\n    def rot_pos_emb(self, grid_thw: mx.array) -> mx.array: ...\n    def __call__(self, pixel_values: mx.array, grid_thw: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/qwen_transformer/qwen_attention.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\nfrom typing import Any\n\nclass QwenAttention(nn.Module):\n    _num_heads: int\n    _head_dim: int\n    num_heads: int\n    head_dim: int\n    to_q: nn.Linear\n    to_k: nn.Linear\n    to_v: nn.Linear\n    add_q_proj: nn.Linear\n    add_k_proj: nn.Linear\n    add_v_proj: nn.Linear\n    norm_q: nn.RMSNorm\n    norm_k: nn.RMSNorm\n    norm_added_q: nn.RMSNorm\n    norm_added_k: nn.RMSNorm\n    attn_to_out: list[Any]\n    to_add_out: nn.Linear\n\n    def __init__(\n        self, dim: int = ..., num_heads: int = ..., head_dim: int = ...\n    ) -> None: ...\n    def __call__(\n        self,\n        img_modulated: mx.array,\n        txt_modulated: mx.array,\n        encoder_hidden_states_mask: mx.array | None,\n        image_rotary_emb: tuple[mx.array, mx.array],\n        block_idx: int | None = ...,\n    ) -> tuple[mx.array, mx.array]: ...\n    def _compute_attention_qwen(\n        self,\n        query: mx.array,\n        key: mx.array,\n        value: mx.array,\n        mask: mx.array | None,\n        block_idx: int | None,\n    ) -> mx.array: ...\n    @staticmethod\n    def _convert_mask_for_qwen(\n        mask: mx.array | None, joint_seq_len: int, txt_seq_len: int\n    ) -> mx.array | None: ...\n    @staticmethod\n    def _apply_rope_qwen(\n        x: mx.array, cos_vals: mx.array, sin_vals: mx.array\n    ) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/qwen_transformer/qwen_feed_forward.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass QwenFeedForward(nn.Module):\n    def __init__(self, dim: int = ...) -> None: ...\n    def __call__(self, hidden_states: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/qwen_transformer/qwen_rope.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass QwenEmbedRopeMLX(nn.Module):\n    def __init__(\n        self, theta: int, axes_dim: list[int], scale_rope: bool = ...\n    ) -> None: ...\n    def __call__(\n        self,\n        video_fhw: tuple[int, int, int] | list[tuple[int, int, int]],\n        txt_seq_lens: list[int],\n    ) -> tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/qwen_transformer/qwen_time_text_embed.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass QwenTimeTextEmbed(nn.Module):\n    def __init__(self, timestep_proj_dim: int = ..., inner_dim: int = ...) -> None: ...\n    def __call__(self, timestep: mx.array, hidden_states: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/qwen_transformer/qwen_timestep_embedding.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass QwenTimestepEmbedding(nn.Module):\n    def __init__(self, proj_dim: int, inner_dim: int) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/qwen_transformer/qwen_timesteps.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass QwenTimesteps(nn.Module):\n    def __init__(self, proj_dim: int = ..., scale: float = ...) -> None: ...\n    def __call__(self, timesteps: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/qwen_transformer/qwen_transformer.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\nfrom typing import Any\nfrom mflux.models.common.config.config import Config\nfrom mflux.models.qwen.model.qwen_transformer.qwen_transformer_block import (\n    QwenTransformerBlock,\n)\n\nclass QwenTransformer(nn.Module):\n    transformer_blocks: list[QwenTransformerBlock]\n    inner_dim: int\n    img_in: nn.Linear\n    txt_in: nn.Linear\n    txt_norm: nn.RMSNorm\n    time_text_embed: Any\n    pos_embed: Any\n    norm_out: nn.Module\n    proj_out: nn.Linear\n\n    def __init__(\n        self,\n        in_channels: int = ...,\n        out_channels: int = ...,\n        num_layers: int = ...,\n        attention_head_dim: int = ...,\n        num_attention_heads: int = ...,\n        joint_attention_dim: int = ...,\n        patch_size: int = ...,\n    ) -> None: ...\n    def __call__(\n        self,\n        t: int,\n        config: Config,\n        hidden_states: mx.array,\n        encoder_hidden_states: mx.array,\n        encoder_hidden_states_mask: mx.array,\n        qwen_image_ids: mx.array | None = ...,\n        cond_image_grid: tuple[int, int, int] | None = ...,\n    ) -> mx.array: ...\n    @staticmethod\n    def _compute_timestep(t: int | float, config: Config) -> mx.array: ...\n    @staticmethod\n    def _compute_rotary_embeddings(\n        encoder_hidden_states_mask: mx.array,\n        pos_embed: Any,\n        config: Config,\n        cond_image_grid: tuple[int, int, int] | list[tuple[int, int, int]] | None = ...,\n    ) -> tuple[mx.array, mx.array]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/qwen_transformer/qwen_transformer_block.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\nfrom typing import Any\nfrom mflux.models.qwen.model.qwen_transformer.qwen_attention import QwenAttention\n\nclass QwenTransformerBlock(nn.Module):\n    attn: QwenAttention\n    img_mod_linear: nn.Linear\n    img_mod_silu: nn.SiLU\n    txt_mod_linear: nn.Linear\n    txt_mod_silu: nn.SiLU\n    img_norm1: nn.RMSNorm\n    txt_norm1: nn.RMSNorm\n    img_norm2: nn.RMSNorm\n    txt_norm2: nn.RMSNorm\n    img_ff: Any\n    txt_ff: Any\n\n    def __init__(\n        self, dim: int = ..., num_heads: int = ..., head_dim: int = ...\n    ) -> None: ...\n    def __call__(\n        self,\n        hidden_states: mx.array,\n        encoder_hidden_states: mx.array,\n        encoder_hidden_states_mask: mx.array | None,\n        text_embeddings: mx.array,\n        image_rotary_emb: tuple[mx.array, mx.array],\n        block_idx: int | None = ...,\n    ) -> tuple[mx.array, mx.array]: ...\n    @staticmethod\n    def _modulate(x: mx.array, mod_params: mx.array) -> tuple[mx.array, mx.array]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/qwen_transformer/qwen_transformer_rms_norm.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass QwenTransformerRMSNorm(nn.Module):\n    def __init__(self, dim: int, eps: float = ...) -> None: ...\n    def __call__(self, hidden_states: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/qwen_vae/qwen_image_attention_block_3d.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass QwenImageAttentionBlock3D(nn.Module):\n    def __init__(self, dim: int) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/qwen_vae/qwen_image_causal_conv_3d.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass QwenImageCausalConv3D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: int = ...,\n        stride: int = ...,\n        padding: int = ...,\n    ) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/qwen_vae/qwen_image_decoder_3d.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass QwenImageDecoder3D(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/qwen_vae/qwen_image_down_block_3d.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass QwenImageDownBlock3D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        num_res_blocks: int = ...,\n        downsample_mode: str = ...,\n    ) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/qwen_vae/qwen_image_encoder_3d.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass QwenImageEncoder3D(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/qwen_vae/qwen_image_mid_block_3d.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass QwenImageMidBlock3D(nn.Module):\n    def __init__(self, dim: int, num_layers: int = ...) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/qwen_vae/qwen_image_res_block_3d.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass QwenImageResBlock3D(nn.Module):\n    def __init__(self, in_channels: int, out_channels: int) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/qwen_vae/qwen_image_resample_3d.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass QwenImageResample3D(nn.Module):\n    def __init__(self, dim: int, mode: str) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/qwen_vae/qwen_image_rms_norm.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass QwenImageRMSNorm(nn.Module):\n    def __init__(\n        self, num_channels: int, eps: float = ..., images: bool = ...\n    ) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/qwen_vae/qwen_image_up_block_3d.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass QwenImageUpBlock3D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        num_res_blocks: int = ...,\n        upsample_mode: str = ...,\n    ) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/model/qwen_vae/qwen_vae.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mlx import nn\n\nclass QwenVAE(nn.Module):\n    LATENTS_MEAN = ...\n    LATENTS_STD = ...\n    spatial_scale = ...\n    latent_channels = ...\n    def __init__(self) -> None: ...\n    def decode(self, latents: mx.array) -> mx.array: ...\n    def encode(self, latents: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/qwen_initializer.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom mflux.models.common.config import ModelConfig\n\nclass QwenImageInitializer:\n    @staticmethod\n    def init(\n        model,\n        model_config: ModelConfig,\n        quantize: int | None,\n        model_path: str | None = ...,\n        lora_paths: list[str] | None = ...,\n        lora_scales: list[float] | None = ...,\n    ) -> None: ...\n    @staticmethod\n    def init_edit(\n        model,\n        model_config: ModelConfig,\n        quantize: int | None,\n        model_path: str | None = ...,\n        lora_paths: list[str] | None = ...,\n        lora_scales: list[float] | None = ...,\n    ) -> None: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/tokenizer/__init__.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/tokenizer/qwen_image_processor.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport numpy as np\nfrom typing import Optional, Union\nfrom PIL import Image\n\nOPENAI_CLIP_MEAN = ...\nOPENAI_CLIP_STD = ...\n\ndef smart_resize(\n    height: int,\n    width: int,\n    factor: int = ...,\n    min_pixels: int = ...,\n    max_pixels: int = ...,\n) -> tuple[int, int]: ...\n\nclass QwenImageProcessor:\n    def __init__(\n        self,\n        min_pixels: int = ...,\n        max_pixels: int = ...,\n        patch_size: int = ...,\n        temporal_patch_size: int = ...,\n        merge_size: int = ...,\n        image_mean: Optional[list[float]] = ...,\n        image_std: Optional[list[float]] = ...,\n    ) -> None: ...\n    def preprocess(\n        self, images: Union[Image.Image, list[Image.Image]]\n    ) -> tuple[np.ndarray, np.ndarray]: ...\n    def get_number_of_image_patches(\n        self,\n        height: int,\n        width: int,\n        min_pixels: Optional[int] = ...,\n        max_pixels: Optional[int] = ...,\n    ) -> int: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/tokenizer/qwen_vision_language_processor.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import Optional, Union\nfrom PIL import Image\nfrom mflux.models.qwen.tokenizer.qwen_image_processor import QwenImageProcessor\n\nclass QwenVisionLanguageProcessor:\n    def __init__(\n        self,\n        tokenizer,\n        image_processor: Optional[QwenImageProcessor] = ...,\n        image_token: str = ...,\n        video_token: str = ...,\n    ) -> None: ...\n    def __call__(\n        self,\n        images: Optional[Union[Image.Image, list[Image.Image]]] = ...,\n        text: Optional[Union[str, list[str]]] = ...,\n        padding: bool = ...,\n        return_tensors: Optional[str] = ...,\n    ) -> dict: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/tokenizer/qwen_vision_language_tokenizer.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nimport numpy as np\nfrom typing import Union\nfrom PIL import Image\nfrom mflux.models.qwen.tokenizer.qwen_vision_language_processor import (\n    QwenVisionLanguageProcessor,\n)\n\nclass QwenVisionLanguageTokenizer:\n    def __init__(\n        self,\n        processor: QwenVisionLanguageProcessor,\n        max_length: int = ...,\n        use_picture_prefix: bool = ...,\n    ) -> None: ...\n    def tokenize_with_image(\n        self,\n        prompt: str,\n        image: Union[Image.Image, np.ndarray, str, list],\n        vl_width: int | None = ...,\n        vl_height: int | None = ...,\n    ) -> tuple[mx.array, mx.array, mx.array, mx.array]: ...\n    def tokenize_text_only(self, prompt: str) -> tuple[mx.array, mx.array]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/variants/__init__.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\n\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/variants/edit/qwen_edit_util.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom mflux.models.common.vae.tiling_config import TilingConfig\n\nclass QwenEditUtil:\n    @staticmethod\n    def create_image_conditioning_latents(\n        vae,\n        height: int,\n        width: int,\n        image_paths: list[str] | str,\n        vl_width: int | None = ...,\n        vl_height: int | None = ...,\n        tiling_config: TilingConfig | None = ...,\n    ) -> tuple[mx.array, mx.array, int, int, int]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/variants/edit/qwen_image_edit.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom pathlib import Path\nfrom mlx import nn\nfrom typing import Any\nfrom mflux.models.common.config import Config\nfrom mflux.models.common.config.model_config import ModelConfig\nfrom mflux.models.qwen.model.qwen_text_encoder.qwen_text_encoder import QwenTextEncoder\nfrom mflux.models.qwen.model.qwen_transformer.qwen_transformer import QwenTransformer\nfrom mflux.models.qwen.model.qwen_vae.qwen_vae import QwenVAE\nfrom mflux.utils.generated_image import GeneratedImage\n\nclass QwenImageEdit(nn.Module):\n    vae: QwenVAE\n    transformer: QwenTransformer\n    text_encoder: QwenTextEncoder\n    bits: int | None\n    lora_paths: list[str] | None\n    lora_scales: list[float] | None\n    prompt_cache: dict[str, Any]\n    tokenizers: dict[str, Any]\n\n    def __init__(\n        self,\n        quantize: int | None = ...,\n        model_path: str | None = ...,\n        lora_paths: list[str] | None = ...,\n        lora_scales: list[float] | None = ...,\n        model_config: ModelConfig = ...,\n    ) -> None: ...\n    def generate_image(\n        self,\n        seed: int,\n        prompt: str,\n        image_paths: list[str],\n        num_inference_steps: int = ...,\n        height: int | None = ...,\n        width: int | None = ...,\n        guidance: float = ...,\n        image_path: Path | str | None = ...,\n        scheduler: str = ...,\n        negative_prompt: str | None = ...,\n    ) -> GeneratedImage: ...\n    def _encode_prompts_with_images(\n        self,\n        prompt: str,\n        negative_prompt: str,\n        image_paths: list[str],\n        config: Config,\n        vl_width: int | None,\n        vl_height: int | None,\n    ) -> tuple[mx.array, mx.array, mx.array, mx.array]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/variants/txt2img/qwen_image.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom pathlib import Path\nfrom mlx import nn\nfrom typing import Any\nfrom mflux.models.common.config import ModelConfig\nfrom mflux.models.qwen.model.qwen_text_encoder.qwen_text_encoder import QwenTextEncoder\nfrom mflux.models.qwen.model.qwen_transformer.qwen_transformer import QwenTransformer\nfrom mflux.models.qwen.model.qwen_vae.qwen_vae import QwenVAE\nfrom mflux.utils.generated_image import GeneratedImage\n\nclass QwenImage(nn.Module):\n    vae: QwenVAE\n    transformer: QwenTransformer\n    text_encoder: QwenTextEncoder\n    bits: int | None\n    lora_paths: list[str] | None\n    lora_scales: list[float] | None\n    prompt_cache: dict[str, Any]\n    tokenizers: dict[str, Any]\n\n    def __init__(\n        self,\n        quantize: int | None = ...,\n        model_path: str | None = ...,\n        lora_paths: list[str] | None = ...,\n        lora_scales: list[float] | None = ...,\n        model_config: ModelConfig = ...,\n    ) -> None: ...\n    def generate_image(\n        self,\n        seed: int,\n        prompt: str,\n        num_inference_steps: int = ...,\n        height: int = ...,\n        width: int = ...,\n        guidance: float = ...,\n        image_path: Path | str | None = ...,\n        image_strength: float | None = ...,\n        scheduler: str = ...,\n        negative_prompt: str | None = ...,\n    ) -> GeneratedImage: ...\n    def save_model(self, base_path: str) -> None: ...\n    @staticmethod\n    def compute_guided_noise(\n        noise: mx.array, noise_negative: mx.array, guidance: float\n    ) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/weights/__init__.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom mflux.models.qwen.weights.qwen_weight_definition import QwenWeightDefinition\nfrom mflux.models.qwen.weights.qwen_weight_mapping import QwenWeightMapping\n\n__all__ = [\"QwenWeightDefinition\", \"QwenWeightMapping\"]\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/weights/qwen_lora_mapping.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import List\nfrom mflux.models.common.lora.mapping.lora_mapping import LoRAMapping, LoRATarget\n\nclass QwenLoRAMapping(LoRAMapping):\n    @staticmethod\n    def get_mapping() -> List[LoRATarget]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/weights/qwen_weight_definition.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import List\nfrom mflux.models.common.weights.loading.weight_definition import (\n    ComponentDefinition,\n    TokenizerDefinition,\n)\n\n\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nclass QwenWeightDefinition:\n    @staticmethod\n    def get_components() -> List[ComponentDefinition]: ...\n    @staticmethod\n    def get_tokenizers() -> List[TokenizerDefinition]: ...\n    @staticmethod\n    def get_download_patterns() -> List[str]: ...\n    @staticmethod\n    def quantization_predicate(path: str, module) -> bool: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/qwen/weights/qwen_weight_mapping.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import List\nfrom mflux.models.common.weights.mapping.weight_mapping import (\n    WeightMapping,\n    WeightTarget,\n)\n\nclass QwenWeightMapping(WeightMapping):\n    @staticmethod\n    def get_transformer_mapping() -> List[WeightTarget]: ...\n    @staticmethod\n    def get_vae_mapping() -> List[WeightTarget]: ...\n    @staticmethod\n    def get_text_encoder_mapping() -> List[WeightTarget]: ...\n    @staticmethod\n    def get_mapping() -> List[WeightTarget]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/seedvr2/weights/seedvr2_weight_definition.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import List\nfrom mflux.models.common.weights.loading.weight_definition import (\n    ComponentDefinition,\n    TokenizerDefinition,\n)\n\n\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nclass SeedVR2WeightDefinition:\n    @staticmethod\n    def get_components() -> List[ComponentDefinition]: ...\n    @staticmethod\n    def get_tokenizers() -> List[TokenizerDefinition]: ...\n    @staticmethod\n    def get_download_patterns() -> List[str]: ...\n    @staticmethod\n    def quantization_predicate(path: str, module) -> bool: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/seedvr2/weights/seedvr2_weight_mapping.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import List\nfrom mflux.models.common.weights.mapping.weight_mapping import (\n    WeightMapping,\n    WeightTarget,\n)\n\nclass SeedVR2WeightMapping(WeightMapping):\n    @staticmethod\n    def get_transformer_mapping() -> List[WeightTarget]: ...\n    @staticmethod\n    def get_vae_mapping() -> List[WeightTarget]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/z_image/latent_creator/z_image_latent_creator.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\n\nclass ZImageLatentCreator:\n    @staticmethod\n    def create_noise(seed: int, height: int, width: int) -> mx.array: ...\n    @staticmethod\n    def pack_latents(latents: mx.array, height: int, width: int) -> mx.array: ...\n    @staticmethod\n    def unpack_latents(latents: mx.array, height: int, width: int) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/z_image/weights/z_image_weight_definition.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import List\nfrom mflux.models.common.weights.loading.weight_definition import (\n    ComponentDefinition,\n    TokenizerDefinition,\n)\n\n\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nclass ZImageWeightDefinition:\n    @staticmethod\n    def get_components() -> List[ComponentDefinition]: ...\n    @staticmethod\n    def get_tokenizers() -> List[TokenizerDefinition]: ...\n    @staticmethod\n    def get_download_patterns() -> List[str]: ...\n    @staticmethod\n    def quantization_predicate(path: str, module) -> bool: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/models/z_image/weights/z_image_weight_mapping.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom mflux.models.common.weights.mapping.weight_mapping import (\n    WeightMapping,\n    WeightTarget,\n)\n\nclass ZImageWeightMapping(WeightMapping):\n    @staticmethod\n    def get_text_encoder_mapping() -> list[WeightTarget]: ...\n    @staticmethod\n    def get_vae_mapping() -> list[WeightTarget]: ...\n    @staticmethod\n    def get_transformer_mapping() -> list[WeightTarget]: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/release/__init__.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n"
  },
  {
    "path": ".mlx_typings/mflux/utils/__init__.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\n\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n"
  },
  {
    "path": ".mlx_typings/mflux/utils/box_values.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom dataclasses import dataclass\n\n@dataclass\nclass AbsoluteBoxValues:\n    top: int\n    right: int\n    bottom: int\n    left: int\n    ...\n\nclass BoxValueError(ValueError): ...\n\n@dataclass\nclass BoxValues:\n    top: int | str\n    right: int | str\n    bottom: int | str\n    left: int | str\n    def normalize_to_dimensions(self, width, height) -> AbsoluteBoxValues: ...\n    @staticmethod\n    def parse(value, delimiter=...) -> BoxValues: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/utils/exceptions.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\n\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nclass MFluxException(Exception): ...\nclass ImageSavingException(MFluxException): ...\nclass MetadataEmbedException(MFluxException): ...\nclass MFluxUserException(MFluxException): ...\nclass PromptFileReadError(MFluxUserException): ...\nclass StopImageGenerationException(MFluxUserException): ...\nclass StopTrainingException(MFluxUserException): ...\n\nclass CommandExecutionError(MFluxException):\n    def __init__(\n        self, cmd: list[str], return_code: int, stdout: str | None, stderr: str | None\n    ) -> None: ...\n\nclass ReferenceVsOutputImageError(AssertionError): ...\nclass ModelConfigError(ValueError): ...\nclass InvalidBaseModel(ModelConfigError): ...\n"
  },
  {
    "path": ".mlx_typings/mflux/utils/generated_image.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nimport PIL.Image\nfrom pathlib import Path\nfrom mflux.models.common.config import ModelConfig\nfrom mflux.models.flux.variants.concept_attention.attention_data import ConceptHeatmap\n\nlog = ...\n\nclass GeneratedImage:\n    image: PIL.Image.Image\n\n    def __init__(\n        self,\n        image: PIL.Image.Image,\n        model_config: ModelConfig,\n        seed: int,\n        prompt: str,\n        steps: int,\n        guidance: float | None,\n        precision: mx.Dtype,\n        quantization: int,\n        generation_time: float,\n        lora_paths: list[str] | None = ...,\n        lora_scales: list[float] | None = ...,\n        height: int | None = ...,\n        width: int | None = ...,\n        controlnet_image_path: str | Path | None = ...,\n        controlnet_strength: float | None = ...,\n        image_path: str | Path | None = ...,\n        image_paths: list[str] | list[Path] | None = ...,\n        image_strength: float | None = ...,\n        masked_image_path: str | Path | None = ...,\n        depth_image_path: str | Path | None = ...,\n        redux_image_paths: list[str] | list[Path] | None = ...,\n        redux_image_strengths: list[float] | None = ...,\n        concept_heatmap: ConceptHeatmap | None = ...,\n        negative_prompt: str | None = ...,\n        init_metadata: dict | None = ...,\n    ) -> None: ...\n    def get_right_half(self) -> GeneratedImage: ...\n    def save(\n        self, path: str | Path, export_json_metadata: bool = ..., overwrite: bool = ...\n    ) -> None: ...\n    def save_with_heatmap(\n        self, path: str | Path, export_json_metadata: bool = ..., overwrite: bool = ...\n    ) -> None: ...\n    def save_concept_heatmap(\n        self, path: str | Path, export_json_metadata: bool = ..., overwrite: bool = ...\n    ) -> None: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/utils/image_util.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nimport PIL.Image\nfrom pathlib import Path\nfrom typing import Any\nfrom PIL._typing import StrOrBytesPath\nfrom mflux.models.common.config.config import Config\nfrom mflux.models.flux.variants.concept_attention.attention_data import ConceptHeatmap\nfrom mflux.utils.box_values import AbsoluteBoxValues\nfrom mflux.utils.generated_image import GeneratedImage\n\nlog = ...\n\nclass ImageUtil:\n    @staticmethod\n    def to_image(\n        decoded_latents: mx.array,\n        config: Config,\n        seed: int,\n        prompt: str,\n        quantization: int,\n        generation_time: float,\n        lora_paths: list[str] | None = ...,\n        lora_scales: list[float] | None = ...,\n        controlnet_image_path: str | Path | None = ...,\n        image_path: str | Path | None = ...,\n        image_paths: list[str] | list[Path] | None = ...,\n        redux_image_paths: list[str] | list[Path] | None = ...,\n        redux_image_strengths: list[float] | None = ...,\n        image_strength: float | None = ...,\n        masked_image_path: str | Path | None = ...,\n        depth_image_path: str | Path | None = ...,\n        concept_heatmap: ConceptHeatmap | None = ...,\n        negative_prompt: str | None = ...,\n        init_metadata: dict[str, Any] | None = ...,\n    ) -> GeneratedImage: ...\n    @staticmethod\n    def to_composite_image(\n        generated_images: list[GeneratedImage],\n    ) -> PIL.Image.Image: ...\n    @staticmethod\n    def to_array(image: PIL.Image.Image, is_mask: bool = ...) -> mx.array: ...\n    @staticmethod\n    def load_image(\n        image_or_path: PIL.Image.Image | StrOrBytesPath,\n    ) -> PIL.Image.Image: ...\n    @staticmethod\n    def expand_image(\n        image: PIL.Image.Image,\n        box_values: AbsoluteBoxValues | None = ...,\n        top: int | str = ...,\n        right: int | str = ...,\n        bottom: int | str = ...,\n        left: int | str = ...,\n        fill_color: tuple = ...,\n    ) -> PIL.Image.Image: ...\n    @staticmethod\n    def create_outpaint_mask_image(\n        orig_width: int, orig_height: int, **create_bordered_image_kwargs\n    ):  # -> Image:\n        ...\n    @staticmethod\n    def create_bordered_image(\n        orig_width: int,\n        orig_height: int,\n        border_color: tuple,\n        content_color: tuple,\n        box_values: AbsoluteBoxValues | None = ...,\n        top: int | str = ...,\n        right: int | str = ...,\n        bottom: int | str = ...,\n        left: int | str = ...,\n    ) -> PIL.Image.Image: ...\n    @staticmethod\n    def scale_to_dimensions(\n        image: PIL.Image.Image, target_width: int, target_height: int\n    ) -> PIL.Image.Image: ...\n    @staticmethod\n    def save_image(\n        image: PIL.Image.Image,\n        path: str | Path,\n        metadata: dict | None = ...,\n        export_json_metadata: bool = ...,\n        overwrite: bool = ...,\n    ) -> None: ...\n    @staticmethod\n    def preprocess_for_model(\n        image: PIL.Image.Image,\n        target_size: tuple = ...,\n        mean: list = ...,\n        std: list = ...,\n        resample: int = ...,\n    ) -> mx.array: ...\n    @staticmethod\n    def preprocess_for_depth_pro(\n        image: PIL.Image.Image,\n        target_size: tuple = ...,\n        mean: list = ...,\n        std: list = ...,\n        resample: int = ...,\n    ) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/utils/metadata_builder.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom pathlib import Path\n\nlog = ...\n\nclass MetadataBuilder:\n    _IPTC_PROMPT_MAX_BYTES = ...\n    @staticmethod\n    def embed_metadata(metadata: dict, path: str | Path) -> None: ...\n    @staticmethod\n    def build_xmp_packet(metadata: dict) -> str: ...\n    @staticmethod\n    def build_iptc_binary(metadata: dict) -> bytes: ...\n"
  },
  {
    "path": ".mlx_typings/mflux/utils/version_util.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\n\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nclass VersionUtil:\n    @staticmethod\n    def get_mflux_version() -> str: ...\n"
  },
  {
    "path": ".mlx_typings/mlx/core/__init__.pyi",
    "content": "import enum\nimport pathlib\nimport types\nfrom typing import (\n    Annotated,\n    Callable,\n    Literal,\n    Mapping,\n    Sequence,\n    TypeAlias,\n    overload,\n)\n\nimport numpy\nfrom mlx.nn.layers import Module\nfrom numpy.typing import ArrayLike as _ArrayLike\n\nfrom . import cuda as cuda\nfrom . import distributed as distributed\nfrom . import metal as metal\nfrom . import random as random\n\nclass ArrayAt:\n    \"\"\"A helper object to apply updates at specific indices.\"\"\"\n    def __getitem__(self, indices: object | None) -> ArrayAt: ...\n    def add(\n        self,\n        value: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def subtract(\n        self,\n        value: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def multiply(\n        self,\n        value: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def divide(\n        self,\n        value: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def maximum(\n        self,\n        value: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def minimum(\n        self,\n        value: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n\nclass ArrayIterator:\n    \"\"\"A helper object to iterate over the 1st dimension of an array.\"\"\"\n    def __next__(self) -> array: ...\n    def __iter__(self) -> ArrayIterator: ...\n\nclass ArrayLike:\n    \"\"\"\n    Any Python object which has an ``__mlx__array__`` method that\n    returns an :obj:`array`.\n    \"\"\"\n    def __init__(self, arg: object, /) -> None: ...\n\nclass Device:\n    \"\"\"A device to run operations on.\"\"\"\n    def __init__(self, type: DeviceType, index: int = ...) -> None: ...\n    @property\n    def type(self) -> DeviceType: ...\n    def __repr__(self) -> str: ...\n    def __eq__(self, arg: object, /) -> bool: ...\n\nclass DeviceType(enum.Enum):\n    cpu = ...  # type: ignore\n    gpu = ...  
# type: ignore\n    def __eq__(self, arg: object, /) -> bool: ...\n\nclass Dtype:\n    \"\"\"\n    An object to hold the type of a :class:`array`.\n\n    See the :ref:`list of types <data_types>` for more details\n    on available data types.\n    \"\"\"\n    @property\n    def size(self) -> int:\n        \"\"\"Size of the type in bytes.\"\"\"\n\n    def __repr__(self) -> str: ...\n    def __eq__(self, arg: object, /) -> bool: ...\n    def __hash__(self) -> int: ...\n\nclass DtypeCategory(enum.Enum):\n    \"\"\"\n    Type to hold categories of :class:`dtypes <Dtype>`.\n\n    * :attr:`~mlx.core.generic`\n\n      * :ref:`bool_ <data_types>`\n      * :attr:`~mlx.core.number`\n\n        * :attr:`~mlx.core.integer`\n\n          * :attr:`~mlx.core.unsignedinteger`\n\n            * :ref:`uint8 <data_types>`\n            * :ref:`uint16 <data_types>`\n            * :ref:`uint32 <data_types>`\n            * :ref:`uint64 <data_types>`\n\n          * :attr:`~mlx.core.signedinteger`\n\n            * :ref:`int8 <data_types>`\n            * :ref:`int16 <data_types>`\n            * :ref:`int32 <data_types>`\n            * :ref:`int64 <data_types>`\n\n        * :attr:`~mlx.core.inexact`\n\n          * :attr:`~mlx.core.floating`\n\n            * :ref:`float16 <data_types>`\n            * :ref:`bfloat16 <data_types>`\n            * :ref:`float32 <data_types>`\n            * :ref:`float64 <data_types>`\n\n          * :attr:`~mlx.core.complexfloating`\n\n            * :ref:`complex64 <data_types>`\n\n    See also :func:`~mlx.core.issubdtype`.\n    \"\"\"\n\n    complexfloating = ...\n    floating = ...\n    inexact = ...\n    signedinteger = ...\n    unsignedinteger = ...\n    integer = ...\n    number = ...\n    generic = ...\n\nclass FunctionExporter:\n    \"\"\"\n    A context managing class for exporting multiple traces of the same\n    function to a file.\n\n    Make an instance of this class by calling :func:`mx.exporter`.\n    \"\"\"\n    def close(self) -> None: ...\n    def __enter__(self) -> FunctionExporter: ...\n    def __exit__(\n        self,\n        exc_type: object | None = ...,\n        exc_value: object | None = ...,\n        traceback: object | None = ...,\n    ) -> None: ...\n    def __call__(self, *args, **kwargs) -> None: ...\n\nclass Stream:\n    \"\"\"A stream for running operations on a given device.\"\"\"\n    @property\n    def device(self) -> Device: ...\n    def __repr__(self) -> str: ...\n    def __eq__(self, arg: object, /) -> bool: ...\n\nclass StreamContext:\n    \"\"\"\n    A context manager for setting the current device and stream.\n\n    See :func:`stream` for usage.\n\n    Args:\n        s: The stream or device to set as the default.\n    \"\"\"\n    def __init__(self, s: Stream | Device) -> None: ...\n    def __enter__(self) -> None: ...\n    def __exit__(\n        self,\n        exc_type: type | None = ...,\n        exc_value: object | None = ...,\n        traceback: object | None = ...,\n    ) -> None: ...\n\ndef device_info() -> dict[str, str | int]:\n    \"\"\"\n    Get information about the GPU device and system settings.\n\n    Currently returns:\n\n    * ``architecture``\n    * ``max_buffer_size``\n    * ``max_recommended_working_set_size``\n    * ``memory_size``\n    * ``resource_limit``\n\n    Returns:\n        dict: A dictionary with string keys and string or integer values.\n    \"\"\"\n\ndef abs(a: array, /, *, stream: Stream | Device | None = ...) 
-> array:\n    \"\"\"\n    Element-wise absolute value.\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: The absolute value of ``a``.\n    \"\"\"\n\ndef add(\n    a: scalar | array,\n    b: scalar | array,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Element-wise addition.\n\n    Add two arrays with numpy-style broadcasting semantics. Either or both input arrays\n    can also be scalars.\n\n    Args:\n        a (array): Input array or scalar.\n        b (array): Input array or scalar.\n\n    Returns:\n        array: The sum of ``a`` and ``b``.\n    \"\"\"\n\ndef addmm(\n    c: array,\n    a: array,\n    b: array,\n    /,\n    alpha: float = ...,\n    beta: float = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Matrix multiplication with addition and optional scaling.\n\n    Perform the (possibly batched) matrix multiplication of two arrays and add to the result\n    with optional scaling factors.\n\n    Args:\n        c (array): Input array or scalar.\n        a (array): Input array or scalar.\n        b (array): Input array or scalar.\n        alpha (float, optional): Scaling factor for the\n            matrix product of ``a`` and ``b`` (default: ``1``)\n        beta (float, optional): Scaling factor for ``c`` (default: ``1``)\n\n    Returns:\n        array: ``alpha * (a @ b) + beta * c``\n    \"\"\"\n\ndef all(\n    a: array,\n    /,\n    axis: int | Sequence[int] | None = ...,\n    keepdims: bool = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    An `and` reduction over the given axes.\n\n    Args:\n        a (array): Input array.\n        axis (int or list(int), optional): Optional axis or\n          axes to reduce over. If unspecified this defaults\n          to reducing over the entire array.\n        keepdims (bool, optional): Keep reduced axes as\n          singleton dimensions, defaults to `False`.\n\n    Returns:\n        array: The output array with the corresponding axes reduced.\n    \"\"\"\n\ndef allclose(\n    a: array,\n    b: array,\n    /,\n    rtol: float = ...,\n    atol: float = ...,\n    *,\n    equal_nan: bool = ...,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Approximate comparison of two arrays.\n\n    Infinite values are considered equal if they have the same sign; NaN values are not equal unless ``equal_nan`` is ``True``.\n\n    The arrays are considered equal if:\n\n    .. code-block::\n\n     all(abs(a - b) <= (atol + rtol * abs(b)))\n\n    Note that unlike :func:`array_equal`, this function supports numpy-style\n    broadcasting.\n\n    Args:\n        a (array): Input array.\n        b (array): Input array.\n        rtol (float): Relative tolerance.\n        atol (float): Absolute tolerance.\n        equal_nan (bool): If ``True``, NaNs are considered equal.\n          Defaults to ``False``.\n\n    Returns:\n        array: The boolean output scalar indicating if the arrays are close.\n    \"\"\"\n\ndef any(\n    a: array,\n    /,\n    axis: int | Sequence[int] | None = ...,\n    keepdims: bool = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    An `or` reduction over the given axes.\n\n    Args:\n        a (array): Input array.\n        axis (int or list(int), optional): Optional axis or\n          axes to reduce over. 
If unspecified this defaults\n          to reducing over the entire array.\n        keepdims (bool, optional): Keep reduced axes as\n          singleton dimensions, defaults to `False`.\n\n    Returns:\n        array: The output array with the corresponding axes reduced.\n    \"\"\"\n\n@overload\ndef arange(\n    start: int | float,\n    stop: int | float,\n    step: int | float | None,\n    dtype: Dtype | None = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Generates ranges of numbers.\n\n    Generate numbers in the half-open interval ``[start, stop)`` in\n    increments of ``step``.\n\n    Args:\n        start (float or int, optional): Starting value which defaults to ``0``.\n        stop (float or int): Stopping value.\n        step (float or int, optional): Increment which defaults to ``1``.\n        dtype (Dtype, optional): Specifies the data type of the output. If unspecified will default to ``float32`` if any of ``start``, ``stop``, or ``step`` are ``float``. Otherwise will default to ``int32``.\n\n    Returns:\n        array: The range of values.\n\n    Note:\n      Following the Numpy convention the actual increment used to\n      generate numbers is ``dtype(start + step) - dtype(start)``.\n      This can lead to unexpected results for example if `start + step`\n      is a fractional value and the `dtype` is integral.\n    \"\"\"\n\n@overload\ndef arange(\n    stop: int | float,\n    step: int | float | None = ...,\n    dtype: Dtype | None = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array: ...\ndef arccos(a: array, /, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Element-wise inverse cosine.\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: The inverse cosine of ``a``.\n    \"\"\"\n\ndef arccosh(a: array, /, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Element-wise inverse hyperbolic cosine.\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: The inverse hyperbolic cosine of ``a``.\n    \"\"\"\n\ndef arcsin(a: array, /, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Element-wise inverse sine.\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: The inverse sine of ``a``.\n    \"\"\"\n\ndef arcsinh(a: array, /, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Element-wise inverse hyperbolic sine.\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: The inverse hyperbolic sine of ``a``.\n    \"\"\"\n\ndef arctan(a: array, /, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Element-wise inverse tangent.\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: The inverse tangent of ``a``.\n    \"\"\"\n\ndef arctan2(a: array, b: array, /, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Element-wise inverse tangent of the ratio of two arrays.\n\n    Args:\n        a (array): Input array.\n        b (array): Input array.\n\n    Returns:\n        array: The inverse tangent of the ratio of ``a`` and ``b``.\n    \"\"\"\n\ndef arctanh(a: array, /, *, stream: Stream | Device | None = ...) 
-> array:\n    \"\"\"\n    Element-wise inverse hyperbolic tangent.\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: The inverse hyperbolic tangent of ``a``.\n    \"\"\"\n\ndef argmax(\n    a: array,\n    /,\n    axis: int | None = ...,\n    keepdims: bool = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Indices of the maximum values along the axis.\n\n    Args:\n        a (array): Input array.\n        axis (int, optional): Optional axis to reduce over. If unspecified\n          this defaults to reducing over the entire array.\n        keepdims (bool, optional): Keep reduced axes as\n          singleton dimensions, defaults to `False`.\n\n    Returns:\n        array: The ``uint32`` array with the indices of the maximum values.\n    \"\"\"\n\ndef argmin(\n    a: array,\n    /,\n    axis: int | None = ...,\n    keepdims: bool = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Indices of the minimum values along the axis.\n\n    Args:\n        a (array): Input array.\n        axis (int, optional): Optional axis to reduce over. If unspecified\n          this defaults to reducing over the entire array.\n        keepdims (bool, optional): Keep reduced axes as\n          singleton dimensions, defaults to `False`.\n\n    Returns:\n        array: The ``uint32`` array with the indices of the minimum values.\n    \"\"\"\n\ndef argpartition(\n    a: array,\n    /,\n    kth: int,\n    axis: int | None = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Returns the indices that partition the array.\n\n    The ordering of the elements within a partition as given by the indices\n    is undefined.\n\n    Args:\n        a (array): Input array.\n        kth (int): Element index at the ``kth`` position in the output will\n          give the sorted position. 
All indices before the ``kth`` position\n          will be of elements less than or equal to the element at the ``kth``\n          index and all indices after will be of elements greater than or equal\n          to the element at the ``kth`` index.\n        axis (int or None, optional): Optional axis to partition over.\n          If ``None``, this partitions over the flattened array.\n          If unspecified, it defaults to ``-1``.\n\n    Returns:\n        array: The ``uint32`` array containing indices that partition the input.\n    \"\"\"\n\ndef argsort(\n    a: array,\n    /,\n    axis: int | None = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Returns the indices that sort the array.\n\n    Args:\n        a (array): Input array.\n        axis (int or None, optional): Optional axis to sort over.\n          If ``None``, this sorts over the flattened array.\n          If unspecified, it defaults to -1 (sorting over the last axis).\n\n    Returns:\n        array: The ``uint32`` array containing indices that sort the input.\n    \"\"\"\n\nclass array:\n    \"\"\"An N-dimensional array object.\"\"\"\n    def __init__(\n        self: array,\n        val: scalar | list | tuple | numpy.ndarray | array,\n        dtype: Dtype | None = ...,\n    ) -> None: ...\n    def __buffer__(self, flags, /):\n        \"\"\"\n        Return a buffer object that exposes the underlying memory of the object.\n        \"\"\"\n\n    def __release_buffer__(self, buffer, /):\n        \"\"\"\n        Release the buffer object that exposes the underlying memory of the object.\n        \"\"\"\n\n    @property\n    def size(self) -> int:\n        \"\"\"Number of elements in the array.\"\"\"\n\n    @property\n    def ndim(self) -> int:\n        \"\"\"The array's dimension.\"\"\"\n\n    @property\n    def itemsize(self) -> int:\n        \"\"\"The size of the array's datatype in bytes.\"\"\"\n\n    @property\n    def nbytes(self) -> int:\n        \"\"\"The number of bytes in the array.\"\"\"\n\n    @property\n    def shape(self) -> tuple[int, ...]:\n        \"\"\"\n        The shape of the array as a Python tuple.\n\n        Returns:\n          tuple(int): A tuple containing the sizes of each dimension.\n        \"\"\"\n\n    @property\n    def dtype(self) -> Dtype:\n        \"\"\"The array's :class:`Dtype`.\"\"\"\n\n    @property\n    def real(self) -> array:\n        \"\"\"The real part of a complex array.\"\"\"\n\n    @property\n    def imag(self) -> array:\n        \"\"\"The imaginary part of a complex array.\"\"\"\n\n    def item(self) -> scalar:\n        \"\"\"\n        Access the value of a scalar array.\n\n        Returns:\n            Standard Python scalar.\n        \"\"\"\n\n    def tolist(self) -> list_or_scalar:\n        \"\"\"\n        Convert the array to a Python :class:`list`.\n\n        Returns:\n            list: The Python list.\n\n            If the array is a scalar then a standard Python scalar is returned.\n\n            If the array has more than one dimension then the result is a nested\n            list of lists.\n\n            The value type of the list corresponding to the last dimension is either\n            ``bool``, ``int`` or ``float`` depending on the ``dtype`` of the array.\n        \"\"\"\n\n    def astype(self, dtype: Dtype, stream: Stream | Device | None = ...) 
-> array:\n        \"\"\"\n        Cast the array to a specified type.\n\n        Args:\n            dtype (Dtype): Type to which the array is cast.\n            stream (Stream): Stream (or device) for the operation.\n\n        Returns:\n            array: The array with type ``dtype``.\n        \"\"\"\n\n    def __array_namespace__(self, api_version: str | None = ...) -> types.ModuleType:\n        \"\"\"\n        Returns an object that has all the array API functions on it.\n\n        See the `Python array API <https://data-apis.org/array-api/latest/index.html>`_\n        for more information.\n\n        Args:\n            api_version (str, optional): String representing the version\n              of the array API spec to return. Default: ``None``.\n\n        Returns:\n            out (Any): An object representing the array API namespace.\n        \"\"\"\n\n    def __getitem__(self, arg: object | None) -> array: ...\n    def __setitem__(\n        self,\n        arg0: object | None,\n        arg1: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> None: ...\n    @property\n    def at(self) -> ArrayAt:\n        \"\"\"\n        Used to apply updates at the given indices.\n\n        .. note::\n\n           Regular in-place updates map to assignment. For instance ``x[idx] += y``\n           maps to ``x[idx] = x[idx] + y``. As a result, assigning to the\n           same index ignores all but one update. Using ``x.at[idx].add(y)``\n           will correctly apply all updates to all indices.\n\n        .. list-table::\n           :header-rows: 1\n\n           * - array.at syntax\n             - In-place syntax\n           * - ``x = x.at[idx].add(y)``\n             - ``x[idx] += y``\n           * - ``x = x.at[idx].subtract(y)``\n             - ``x[idx] -= y``\n           * - ``x = x.at[idx].multiply(y)``\n             - ``x[idx] *= y``\n           * - ``x = x.at[idx].divide(y)``\n             - ``x[idx] /= y``\n           * - ``x = x.at[idx].maximum(y)``\n             - ``x[idx] = mx.maximum(x[idx], y)``\n           * - ``x = x.at[idx].minimum(y)``\n             - ``x[idx] = mx.minimum(x[idx], y)``\n\n        Example:\n            >>> a = mx.array([0, 0])\n            >>> idx = mx.array([0, 1, 0, 1])\n            >>> a[idx] += 1\n            >>> a\n            array([1, 1], dtype=int32)\n            >>>\n            >>> a = mx.array([0, 0])\n            >>> a.at[idx].add(1)\n            array([2, 2], dtype=int32)\n        \"\"\"\n\n    def __len__(self) -> int: ...\n    def __iter__(self) -> ArrayIterator: ...\n    def __getstate__(self) -> tuple: ...\n    def __setstate__(self, arg: tuple, /) -> None: ...\n    def __dlpack__(self) -> _ArrayLike: ...\n    def __dlpack_device__(self) -> tuple: ...\n    def __copy__(self) -> array: ...\n    def __deepcopy__(self, memo: dict) -> array: ...\n    def __add__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __iadd__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __radd__(\n        self,\n        other: bool\n        | int\n        | 
float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __sub__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __isub__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __rsub__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __mul__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __imul__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __rmul__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __truediv__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __itruediv__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __rtruediv__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __div__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __rdiv__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __floordiv__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __ifloordiv__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __rfloordiv__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        
| Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __mod__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __imod__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __rmod__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __eq__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array | bool: ...\n    def __lt__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __le__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __gt__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __ge__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __ne__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array | bool: ...\n    def __neg__(self) -> array: ...\n    def __bool__(self) -> bool: ...\n    def __repr__(self) -> str: ...\n    def __matmul__(self, other: array) -> array: ...\n    def __imatmul__(self, other: array) -> array: ...\n    def __pow__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __rpow__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __ipow__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __invert__(self) -> array: ...\n    def __and__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", 
device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __iand__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __or__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __ior__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __lshift__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __ilshift__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __rshift__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __irshift__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __xor__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __ixor__(\n        self,\n        other: bool\n        | int\n        | float\n        | array\n        | Annotated[_ArrayLike, dict(order=\"C\", device=\"cpu\", writable=False)]\n        | complex\n        | ArrayLike,\n    ) -> array: ...\n    def __int__(self) -> int: ...\n    def __float__(self) -> float: ...\n    def flatten(\n        self,\n        start_axis: int = ...,\n        end_axis: int = ...,\n        *,\n        stream: Stream | Device | None = ...,\n    ) -> array:\n        \"\"\"See :func:`flatten`.\"\"\"\n\n    def reshape(self, *shape: int, stream: Stream | Device | None = ...) -> array:\n        \"\"\"\n        Equivalent to :func:`reshape` but the shape can be passed either as a\n        :obj:`tuple` or as separate arguments.\n\n        See :func:`reshape` for full documentation.\n        \"\"\"\n\n    def squeeze(\n        self,\n        axis: int | Sequence[int] | None = ...,\n        *,\n        stream: Stream | Device | None = ...,\n    ) -> array:\n        \"\"\"See :func:`squeeze`.\"\"\"\n\n    def abs(self, *, stream: Stream | Device | None = ...) -> array:\n        \"\"\"See :func:`abs`.\"\"\"\n\n    def __abs__(self) -> array:\n        \"\"\"See :func:`abs`.\"\"\"\n\n    def square(self, *, stream: Stream | Device | None = ...) -> array:\n        \"\"\"See :func:`square`.\"\"\"\n\n    def sqrt(self, *, stream: Stream | Device | None = ...) 
-> array:\n        \"\"\"See :func:`sqrt`.\"\"\"\n\n    def rsqrt(self, *, stream: Stream | Device | None = ...) -> array:\n        \"\"\"See :func:`rsqrt`.\"\"\"\n\n    def reciprocal(self, *, stream: Stream | Device | None = ...) -> array:\n        \"\"\"See :func:`reciprocal`.\"\"\"\n\n    def exp(self, *, stream: Stream | Device | None = ...) -> array:\n        \"\"\"See :func:`exp`.\"\"\"\n\n    def log(self, *, stream: Stream | Device | None = ...) -> array:\n        \"\"\"See :func:`log`.\"\"\"\n\n    def log2(self, *, stream: Stream | Device | None = ...) -> array:\n        \"\"\"See :func:`log2`.\"\"\"\n\n    def log10(self, *, stream: Stream | Device | None = ...) -> array:\n        \"\"\"See :func:`log10`.\"\"\"\n\n    def sin(self, *, stream: Stream | Device | None = ...) -> array:\n        \"\"\"See :func:`sin`.\"\"\"\n\n    def cos(self, *, stream: Stream | Device | None = ...) -> array:\n        \"\"\"See :func:`cos`.\"\"\"\n\n    def log1p(self, *, stream: Stream | Device | None = ...) -> array:\n        \"\"\"See :func:`log1p`.\"\"\"\n\n    def all(\n        self,\n        axis: int | Sequence[int] | None = ...,\n        keepdims: bool = ...,\n        *,\n        stream: Stream | Device | None = ...,\n    ) -> array:\n        \"\"\"See :func:`all`.\"\"\"\n\n    def any(\n        self,\n        axis: int | Sequence[int] | None = ...,\n        keepdims: bool = ...,\n        *,\n        stream: Stream | Device | None = ...,\n    ) -> array:\n        \"\"\"See :func:`any`.\"\"\"\n\n    def moveaxis(\n        self, source: int, destination: int, *, stream: Stream | Device | None = ...\n    ) -> array:\n        \"\"\"See :func:`moveaxis`.\"\"\"\n\n    def swapaxes(\n        self, axis1: int, axis2: int, *, stream: Stream | Device | None = ...\n    ) -> array:\n        \"\"\"See :func:`swapaxes`.\"\"\"\n\n    def transpose(self, *axes: int, stream: Stream | Device | None = ...) 
-> array:\n        \"\"\"\n        Equivalent to :func:`transpose` but the axes can be passed either as\n        a tuple or as separate arguments.\n\n        See :func:`transpose` for full documentation.\n        \"\"\"\n\n    @property\n    def T(self) -> array:\n        \"\"\"Equivalent to calling ``self.transpose()`` with no arguments.\"\"\"\n\n    def sum(\n        self,\n        axis: int | Sequence[int] | None = ...,\n        keepdims: bool = ...,\n        *,\n        stream: Stream | Device | None = ...,\n    ) -> array:\n        \"\"\"See :func:`sum`.\"\"\"\n\n    def prod(\n        self,\n        axis: int | Sequence[int] | None = ...,\n        keepdims: bool = ...,\n        *,\n        stream: Stream | Device | None = ...,\n    ) -> array:\n        \"\"\"See :func:`prod`.\"\"\"\n\n    def min(\n        self,\n        axis: int | Sequence[int] | None = ...,\n        keepdims: bool = ...,\n        *,\n        stream: Stream | Device | None = ...,\n    ) -> array:\n        \"\"\"See :func:`min`.\"\"\"\n\n    def max(\n        self,\n        axis: int | Sequence[int] | None = ...,\n        keepdims: bool = ...,\n        *,\n        stream: Stream | Device | None = ...,\n    ) -> array:\n        \"\"\"See :func:`max`.\"\"\"\n\n    def logcumsumexp(\n        self,\n        axis: int | None = ...,\n        *,\n        reverse: bool = ...,\n        inclusive: bool = ...,\n        stream: Stream | Device | None = ...,\n    ) -> array:\n        \"\"\"See :func:`logcumsumexp`.\"\"\"\n\n    def logsumexp(\n        self,\n        axis: int | Sequence[int] | None = ...,\n        keepdims: bool = ...,\n        *,\n        stream: Stream | Device | None = ...,\n    ) -> array:\n        \"\"\"See :func:`logsumexp`.\"\"\"\n\n    def mean(\n        self,\n        axis: int | Sequence[int] | None = ...,\n        keepdims: bool = ...,\n        *,\n        stream: Stream | Device | None = ...,\n    ) -> array:\n        \"\"\"See :func:`mean`.\"\"\"\n\n    def std(\n        self,\n        axis: int | Sequence[int] | None = ...,\n        keepdims: bool = ...,\n        ddof: int = ...,\n        *,\n        stream: Stream | Device | None = ...,\n    ) -> array:\n        \"\"\"See :func:`std`.\"\"\"\n\n    def var(\n        self,\n        axis: int | Sequence[int] | None = ...,\n        keepdims: bool = ...,\n        ddof: int = ...,\n        *,\n        stream: Stream | Device | None = ...,\n    ) -> array:\n        \"\"\"See :func:`var`.\"\"\"\n\n    def split(\n        self,\n        indices_or_sections: int | tuple[int, ...],\n        axis: int = ...,\n        *,\n        stream: Stream | Device | None = ...,\n    ) -> list[array]:\n        \"\"\"See :func:`split`.\"\"\"\n\n    def argmin(\n        self,\n        axis: int | None = ...,\n        keepdims: bool = ...,\n        *,\n        stream: Stream | Device | None = ...,\n    ) -> array:\n        \"\"\"See :func:`argmin`.\"\"\"\n\n    def argmax(\n        self,\n        axis: int | None = ...,\n        keepdims: bool = ...,\n        *,\n        stream: Stream | Device | None = ...,\n    ) -> array:\n        \"\"\"See :func:`argmax`.\"\"\"\n\n    def cumsum(\n        self,\n        axis: int | None = ...,\n        *,\n        reverse: bool = ...,\n        inclusive: bool = ...,\n        stream: Stream | Device | None = ...,\n    ) -> array:\n        \"\"\"See :func:`cumsum`.\"\"\"\n\n    def cumprod(\n        self,\n        axis: int | None = ...,\n        *,\n        reverse: bool = ...,\n        inclusive: bool = ...,\n        stream: Stream | 
Device | None = ...,\n    ) -> array:\n        \"\"\"See :func:`cumprod`.\"\"\"\n\n    def cummax(\n        self,\n        axis: int | None = ...,\n        *,\n        reverse: bool = ...,\n        inclusive: bool = ...,\n        stream: Stream | Device | None = ...,\n    ) -> array:\n        \"\"\"See :func:`cummax`.\"\"\"\n\n    def cummin(\n        self,\n        axis: int | None = ...,\n        *,\n        reverse: bool = ...,\n        inclusive: bool = ...,\n        stream: Stream | Device | None = ...,\n    ) -> array:\n        \"\"\"See :func:`cummin`.\"\"\"\n\n    def round(\n        self, decimals: int = ..., *, stream: Stream | Device | None = ...\n    ) -> array:\n        \"\"\"See :func:`round`.\"\"\"\n\n    def diagonal(\n        self,\n        offset: int = ...,\n        axis1: int = ...,\n        axis2: int = ...,\n        stream: Stream | Device | None = ...,\n    ) -> array:\n        \"\"\"See :func:`diagonal`.\"\"\"\n\n    def diag(self, k: int = ..., *, stream: Stream | Device | None = ...) -> array:\n        \"\"\"Extract a diagonal or construct a diagonal matrix.\"\"\"\n\n    def conj(self, *, stream: Stream | Device | None = ...) -> array:\n        \"\"\"See :func:`conj`.\"\"\"\n\n    def view(self, dtype: Dtype, *, stream: Stream | Device | None = ...) -> array:\n        \"\"\"See :func:`view`.\"\"\"\n\ndef array_equal(\n    a: scalar | array,\n    b: scalar | array,\n    equal_nan: bool = ...,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Array equality check.\n\n    Compare two arrays for equality. Returns ``True`` if and only if the arrays\n    have the same shape and their values are equal. The arrays need not have\n    the same type to be considered equal.\n\n    Args:\n        a (array): Input array or scalar.\n        b (array): Input array or scalar.\n        equal_nan (bool): If ``True``, NaNs are considered equal.\n          Defaults to ``False``.\n\n    Returns:\n        array: A scalar boolean array.\n    \"\"\"\n\ndef as_strided(\n    a: array,\n    /,\n    shape: Sequence[int] | None = ...,\n    strides: Sequence[int] | None = ...,\n    offset: int = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Create a view into the array with the given shape and strides.\n\n    The resulting array will always be as if the provided array was row\n    contiguous regardless of the provided array's storage order and current\n    strides.\n\n    .. note::\n       Note that this function should be used with caution as it changes\n       the shape and strides of the array directly. This can lead to the\n       resulting array pointing to invalid memory locations which can\n       result in crashes.\n\n    Args:\n      a (array): Input array\n      shape (list(int), optional): The shape of the resulting array. If\n        None it defaults to ``a.shape()``.\n      strides (list(int), optional): The strides of the resulting array. If\n        None it defaults to the reverse exclusive cumulative product of\n        ``a.shape()``.\n      offset (int): Skip that many elements from the beginning of the input\n        array.\n\n    Returns:\n      array: The output array which is the strided view of the input.\n    \"\"\"\n\ndef async_eval(*args: MX_ARRAY_TREE) -> None:\n    \"\"\"\n    Asynchronously evaluate an :class:`array` or tree of :class:`array`.\n\n    .. 
note::\n\n      This is an experimental API and may change in future versions.\n\n    Args:\n        *args (arrays or trees of arrays): Each argument can be a single array\n          or a tree of arrays. If a tree is given the nodes can be a Python\n          :class:`list`, :class:`tuple` or :class:`dict`. Leaves which are not\n          arrays are ignored.\n\n    Example:\n        >>> x = mx.array(1.0)\n        >>> y = mx.exp(x)\n        >>> mx.async_eval(y)\n        >>> print(y)\n        >>>\n        >>> y = mx.exp(x)\n        >>> mx.async_eval(y)\n        >>> z = y + 3\n        >>> mx.async_eval(z)\n        >>> print(z)\n    \"\"\"\n\ndef atleast_1d(\n    *arys: array, stream: Stream | Device | None = ...\n) -> array | list[array]:\n    \"\"\"\n    Convert all arrays to have at least one dimension.\n\n    Args:\n        *arys: Input arrays.\n        stream (Stream | Device | None, optional): The stream to execute the operation on.\n\n    Returns:\n        array or list(array): An array or list of arrays with at least one dimension.\n    \"\"\"\n\ndef atleast_2d(\n    *arys: array, stream: Stream | Device | None = ...\n) -> array | list[array]:\n    \"\"\"\n    Convert all arrays to have at least two dimensions.\n\n    Args:\n        *arys: Input arrays.\n        stream (Stream | Device | None, optional): The stream to execute the operation on.\n\n    Returns:\n        array or list(array): An array or list of arrays with at least two dimensions.\n    \"\"\"\n\ndef atleast_3d(\n    *arys: array, stream: Stream | Device | None = ...\n) -> array | list[array]:\n    \"\"\"\n    Convert all arrays to have at least three dimensions.\n\n    Args:\n        *arys: Input arrays.\n        stream (Stream | Device | None, optional): The stream to execute the operation on.\n\n    Returns:\n        array or list(array): An array or list of arrays with at least three dimensions.\n    \"\"\"\n\nbfloat16: Dtype = ...\n\ndef bitwise_and(\n    a: scalar | array,\n    b: scalar | array,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Element-wise bitwise and.\n\n    Take the bitwise and of two arrays with numpy-style broadcasting\n    semantics. Either or both input arrays can also be scalars.\n\n    Args:\n        a (array): Input array or scalar.\n        b (array): Input array or scalar.\n\n    Returns:\n        array: The bitwise and ``a & b``.\n    \"\"\"\n\ndef bitwise_invert(a: scalar | array, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Element-wise bitwise inverse.\n\n    Take the bitwise complement of the input.\n\n    Args:\n        a (array): Input array or scalar.\n\n    Returns:\n        array: The bitwise inverse ``~a``.\n    \"\"\"\n\ndef bitwise_or(\n    a: scalar | array,\n    b: scalar | array,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Element-wise bitwise or.\n\n    Take the bitwise or of two arrays with numpy-style broadcasting\n    semantics. Either or both input arrays can also be scalars.\n\n    Args:\n        a (array): Input array or scalar.\n        b (array): Input array or scalar.\n\n    Returns:\n        array: The bitwise or ``a | b``.\n    \"\"\"\n\ndef bitwise_xor(\n    a: scalar | array,\n    b: scalar | array,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Element-wise bitwise xor.\n\n    Take the bitwise exclusive or of two arrays with numpy-style\n    broadcasting semantics. 
Either or both input arrays can also be\n    scalars.\n\n    Args:\n        a (array): Input array or scalar.\n        b (array): Input array or scalar.\n\n    Returns:\n        array: The bitwise xor ``a ^ b``.\n    \"\"\"\n\ndef block_masked_mm(\n    a: array,\n    b: array,\n    /,\n    block_size: int = ...,\n    mask_out: array | None = ...,\n    mask_lhs: array | None = ...,\n    mask_rhs: array | None = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    r\"\"\"\n    Matrix multiplication with block masking.\n\n    Perform the (possibly batched) matrix multiplication of two arrays, with blocks\n    of size ``block_size x block_size`` optionally masked out.\n\n    Assuming ``a`` with shape (..., `M`, `K`) and ``b`` with shape (..., `K`, `N`)\n\n    * ``mask_lhs`` must have shape (..., :math:`\\lceil` `M` / ``block_size`` :math:`\\rceil`, :math:`\\lceil` `K` / ``block_size`` :math:`\\rceil`)\n\n    * ``mask_rhs`` must have shape (..., :math:`\\lceil` `K` / ``block_size`` :math:`\\rceil`, :math:`\\lceil` `N` / ``block_size`` :math:`\\rceil`)\n\n    * ``mask_out`` must have shape (..., :math:`\\lceil` `M` / ``block_size`` :math:`\\rceil`, :math:`\\lceil` `N` / ``block_size`` :math:`\\rceil`)\n\n    Note: Only ``block_size=64`` and ``block_size=32`` are currently supported\n\n    Args:\n        a (array): Input array or scalar.\n        b (array): Input array or scalar.\n        block_size (int): Size of blocks to be masked. Must be ``32`` or ``64``. Default: ``64``.\n        mask_out (array, optional): Mask for output. Default: ``None``.\n        mask_lhs (array, optional): Mask for ``a``. Default: ``None``.\n        mask_rhs (array, optional): Mask for ``b``. Default: ``None``.\n\n    Returns:\n        array: The output array.\n    \"\"\"\n\ndef broadcast_arrays(\n    *arrays: array, stream: Stream | Device | None = ...\n) -> tuple[array, ...]:\n    \"\"\"\n    Broadcast arrays against one another.\n\n    The broadcasting semantics are the same as Numpy.\n\n    Args:\n        *arrays (array): The input arrays.\n\n    Returns:\n        tuple(array): The output arrays with the broadcasted shape.\n    \"\"\"\n\ndef broadcast_shapes(*shapes: Sequence[int]) -> tuple[int, ...]:\n    \"\"\"\n    Broadcast shapes.\n\n    Returns the shape that results from broadcasting the supplied array shapes\n    against each other.\n\n    Args:\n        *shapes (Sequence[int]): The shapes to broadcast.\n\n    Returns:\n        tuple: The broadcasted shape.\n\n    Raises:\n        ValueError: If the shapes cannot be broadcast.\n\n    Example:\n        >>> mx.broadcast_shapes((1,), (3, 1))\n        (3, 1)\n        >>> mx.broadcast_shapes((6, 7), (5, 6, 1), (7,))\n        (5, 6, 7)\n        >>> mx.broadcast_shapes((5, 1, 4), (1, 3, 1))\n        (5, 3, 4)\n    \"\"\"\n\ndef broadcast_to(\n    a: scalar | array,\n    /,\n    shape: Sequence[int],\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Broadcast an array to the given shape.\n\n    The broadcasting semantics are the same as Numpy.\n\n    Args:\n        a (array): Input array.\n        shape (list(int)): The shape to broadcast to.\n\n    Returns:\n        array: The output array with the new shape.\n    \"\"\"\n\ndef ceil(a: array, /, *, stream: Stream | Device | None = ...) 
-> array:\n    \"\"\"\n    Element-wise ceil.\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: The ceil of ``a``.\n    \"\"\"\n\ndef checkpoint(fun: Callable) -> Callable: ...\ndef clear_cache() -> None:\n    \"\"\"\n    Clear the memory cache.\n\n    After calling this, :func:`get_cache_memory` should return ``0``.\n    \"\"\"\n\ndef clip(\n    a: array,\n    /,\n    a_min: scalar | array | None,\n    a_max: scalar | array | None,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Clip the values of the array between the given minimum and maximum.\n\n    If either ``a_min`` or ``a_max`` is ``None``, then the corresponding edge\n    is ignored. ``a_min`` and ``a_max`` cannot both be ``None``.\n    The input ``a`` and the limits must broadcast with one another.\n\n    Args:\n        a (array): Input array.\n        a_min (scalar or array or None): Minimum value to clip to.\n        a_max (scalar or array or None): Maximum value to clip to.\n\n    Returns:\n        array: The clipped array.\n    \"\"\"\n\ndef compile(\n    fun: Callable,\n    inputs: object | None = ...,\n    outputs: object | None = ...,\n    shapeless: bool = ...,\n) -> Callable:\n    \"\"\"\n    Returns a compiled function which produces the same output as ``fun``.\n\n    Args:\n        fun (Callable): A function which takes a variable number of\n          :class:`array` or trees of :class:`array` and returns\n          a variable number of :class:`array` or trees of :class:`array`.\n        inputs (list or dict, optional): These inputs will be captured during\n          the function compilation along with the inputs to ``fun``. The ``inputs``\n          can be a :obj:`list` or a :obj:`dict` containing arbitrarily nested\n          lists, dictionaries, or arrays. Leaf nodes that are not\n          :obj:`array` are ignored. Default: ``None``\n        outputs (list or dict, optional): These outputs will be captured and\n          updated in a compiled function. The ``outputs`` can be a\n          :obj:`list` or a :obj:`dict` containing arbitrarily nested lists,\n          dictionaries, or arrays. Leaf nodes that are not :obj:`array` are ignored.\n          Default: ``None``\n        shapeless (bool, optional): A function compiled with the ``shapeless``\n          option enabled will not be recompiled when the input shape changes. Not all\n          functions can be compiled with ``shapeless`` enabled. Attempting to compile\n          such functions with shapeless enabled will throw. Note that changing the number\n          of dimensions or type of any input will result in a recompilation even with\n          ``shapeless`` set to ``True``. Default: ``False``\n\n    Returns:\n        Callable: A compiled function which has the same input arguments\n        as ``fun`` and returns the same output(s).\n    \"\"\"\n\ncomplex64: Dtype = ...\ncomplexfloating: DtypeCategory = ...\n\ndef concat(\n    arrays: list[array],\n    axis: int | None = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"See :func:`concatenate`.\"\"\"\n\ndef concatenate(\n    arrays: list[array],\n    axis: int | None = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Concatenate the arrays along the given axis.\n\n    Args:\n        arrays (list(array)): Input :obj:`list` or :obj:`tuple` of arrays.\n        axis (int, optional): Optional axis to concatenate along. 
If\n          unspecified defaults to ``0``.\n\n    Returns:\n        array: The concatenated array.\n    \"\"\"\n\ndef conj(a: array, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Return the elementwise complex conjugate of the input.\n    Alias for `mx.conjugate`.\n\n    Args:\n      a (array): Input array\n\n    Returns:\n      array: The output array.\n    \"\"\"\n\ndef conjugate(a: array, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Return the elementwise complex conjugate of the input.\n    Alias for `mx.conj`.\n\n    Args:\n      a (array): Input array\n\n    Returns:\n      array: The output array.\n    \"\"\"\n\ndef contiguous(\n    a: array,\n    /,\n    allow_col_major: bool = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Force an array to be row contiguous. Copy if necessary.\n\n    Args:\n      a (array): The input to make contiguous\n      allow_col_major (bool): Consider column major as contiguous and don't copy\n\n    Returns:\n      array: The row or col contiguous output.\n    \"\"\"\n\ndef conv1d(\n    input: array,\n    weight: array,\n    /,\n    stride: int = ...,\n    padding: int = ...,\n    dilation: int = ...,\n    groups: int = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    1D convolution over an input with several channels\n\n    Args:\n        input (array): Input array of shape ``(N, L, C_in)``.\n        weight (array): Weight array of shape ``(C_out, K, C_in)``.\n        stride (int, optional): Kernel stride. Default: ``1``.\n        padding (int, optional): Input padding. Default: ``0``.\n        dilation (int, optional): Kernel dilation. Default: ``1``.\n        groups (int, optional): Input feature groups. Default: ``1``.\n\n    Returns:\n        array: The convolved array.\n    \"\"\"\n\ndef conv2d(\n    input: array,\n    weight: array,\n    /,\n    stride: int | tuple[int, int] = ...,\n    padding: int | tuple[int, int] = ...,\n    dilation: int | tuple[int, int] = ...,\n    groups: int = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    2D convolution over an input with several channels\n\n    Args:\n        input (array): Input array of shape ``(N, H, W, C_in)``.\n        weight (array): Weight array of shape ``(C_out, KH, KW, C_in)``.\n        stride (int or tuple(int), optional): :obj:`tuple` of size 2 with\n            kernel strides. All spatial dimensions get the same stride if\n            only one number is specified. Default: ``1``.\n        padding (int or tuple(int), optional): :obj:`tuple` of size 2 with\n            symmetric input padding. All spatial dimensions get the same\n            padding if only one number is specified. Default: ``0``.\n        dilation (int or tuple(int), optional): :obj:`tuple` of size 2 with\n            kernel dilation. All spatial dimensions get the same dilation\n            if only one number is specified. Default: ``1``\n        groups (int, optional): input feature groups. 
\ndef conv3d(\n    input: array,\n    weight: array,\n    /,\n    stride: int | tuple[int, int, int] = ...,\n    padding: int | tuple[int, int, int] = ...,\n    dilation: int | tuple[int, int, int] = ...,\n    groups: int = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    3D convolution over an input with several channels\n\n    Note: Only the default ``groups=1`` is currently supported.\n\n    Args:\n        input (array): Input array of shape ``(N, D, H, W, C_in)``.\n        weight (array): Weight array of shape ``(C_out, KD, KH, KW, C_in)``.\n        stride (int or tuple(int), optional): :obj:`tuple` of size 3 with\n            kernel strides. All spatial dimensions get the same stride if\n            only one number is specified. Default: ``1``.\n        padding (int or tuple(int), optional): :obj:`tuple` of size 3 with\n            symmetric input padding. All spatial dimensions get the same\n            padding if only one number is specified. Default: ``0``.\n        dilation (int or tuple(int), optional): :obj:`tuple` of size 3 with\n            kernel dilation. All spatial dimensions get the same dilation\n            if only one number is specified. Default: ``1``.\n        groups (int, optional): Input feature groups. Default: ``1``.\n\n    Returns:\n        array: The convolved array.\n    \"\"\"\n\ndef conv_general(\n    input: array,\n    weight: array,\n    /,\n    stride: int | Sequence[int] = ...,\n    padding: int | Sequence[int] | tuple[Sequence[int], Sequence[int]] = ...,\n    kernel_dilation: int | Sequence[int] = ...,\n    input_dilation: int | Sequence[int] = ...,\n    groups: int = ...,\n    flip: bool = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    General convolution over an input with several channels\n\n    Args:\n        input (array): Input array of shape ``(N, ..., C_in)``.\n        weight (array): Weight array of shape ``(C_out, ..., C_in)``.\n        stride (int or list(int), optional): :obj:`list` with kernel strides.\n            All spatial dimensions get the same stride if\n            only one number is specified. Default: ``1``.\n        padding (int, list(int), or tuple(list(int), list(int)), optional):\n            :obj:`list` with input padding. All spatial dimensions get the same\n            padding if only one number is specified. Default: ``0``.\n        kernel_dilation (int or list(int), optional): :obj:`list` with\n            kernel dilation. All spatial dimensions get the same dilation\n            if only one number is specified. Default: ``1``.\n        input_dilation (int or list(int), optional): :obj:`list` with\n            input dilation. All spatial dimensions get the same dilation\n            if only one number is specified. Default: ``1``.\n        groups (int, optional): Input feature groups. Default: ``1``.\n        flip (bool, optional): Flip the order in which the spatial dimensions of\n            the weights are processed. 
Performs the cross-correlation operator when\n            ``flip`` is ``False`` and the convolution operator otherwise.\n            Default: ``False``.\n\n    Returns:\n        array: The convolved array.\n    \"\"\"\n\ndef conv_transpose1d(\n    input: array,\n    weight: array,\n    /,\n    stride: int = ...,\n    padding: int = ...,\n    dilation: int = ...,\n    output_padding: int = ...,\n    groups: int = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    1D transposed convolution over an input with several channels\n\n    Args:\n        input (array): Input array of shape ``(N, L, C_in)``.\n        weight (array): Weight array of shape ``(C_out, K, C_in)``.\n        stride (int, optional): Kernel stride. Default: ``1``.\n        padding (int, optional): Input padding. Default: ``0``.\n        dilation (int, optional): Kernel dilation. Default: ``1``.\n        output_padding (int, optional): Output padding. Default: ``0``.\n        groups (int, optional): Input feature groups. Default: ``1``.\n\n    Returns:\n        array: The convolved array.\n    \"\"\"\n\ndef conv_transpose2d(\n    input: array,\n    weight: array,\n    /,\n    stride: int | tuple[int, int] = ...,\n    padding: int | tuple[int, int] = ...,\n    dilation: int | tuple[int, int] = ...,\n    output_padding: int | tuple[int, int] = ...,\n    groups: int = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    2D transposed convolution over an input with several channels\n\n    Note: Only the default ``groups=1`` is currently supported.\n\n    Args:\n        input (array): Input array of shape ``(N, H, W, C_in)``.\n        weight (array): Weight array of shape ``(C_out, KH, KW, C_in)``.\n        stride (int or tuple(int), optional): :obj:`tuple` of size 2 with\n            kernel strides. All spatial dimensions get the same stride if\n            only one number is specified. Default: ``1``.\n        padding (int or tuple(int), optional): :obj:`tuple` of size 2 with\n            symmetric input padding. All spatial dimensions get the same\n            padding if only one number is specified. Default: ``0``.\n        dilation (int or tuple(int), optional): :obj:`tuple` of size 2 with\n            kernel dilation. All spatial dimensions get the same dilation\n            if only one number is specified. Default: ``1``\n        output_padding (int or tuple(int), optional): :obj:`tuple` of size 2 with\n            output padding. All spatial dimensions get the same output\n            padding if only one number is specified. Default: ``0``.\n        groups (int, optional): input feature groups. Default: ``1``.\n\n    Returns:\n        array: The convolved array.\n    \"\"\"\n\ndef conv_transpose3d(\n    input: array,\n    weight: array,\n    /,\n    stride: int | tuple[int, int, int] = ...,\n    padding: int | tuple[int, int, int] = ...,\n    dilation: int | tuple[int, int, int] = ...,\n    output_padding: int | tuple[int, int, int] = ...,\n    groups: int = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    3D transposed convolution over an input with several channels\n\n    Note: Only the default ``groups=1`` is currently supported.\n\n    Args:\n        input (array): Input array of shape ``(N, D, H, W, C_in)``.\n        weight (array): Weight array of shape ``(C_out, KD, KH, KW, C_in)``.\n        stride (int or tuple(int), optional): :obj:`tuple` of size 3 with\n            kernel strides. 
All spatial dimensions get the same stride if\n            only one number is specified. Default: ``1``.\n        padding (int or tuple(int), optional): :obj:`tuple` of size 3 with\n            symmetric input padding. All spatial dimensions get the same\n            padding if only one number is specified. Default: ``0``.\n        dilation (int or tuple(int), optional): :obj:`tuple` of size 3 with\n            kernel dilation. All spatial dimensions get the same dilation\n            if only one number is specified. Default: ``1``.\n        output_padding (int or tuple(int), optional): :obj:`tuple` of size 3 with\n            output padding. All spatial dimensions get the same output\n            padding if only one number is specified. Default: ``0``.\n        groups (int, optional): Input feature groups. Default: ``1``.\n\n    Returns:\n        array: The convolved array.\n    \"\"\"\n\ndef convolve(\n    a: array, v: array, /, mode: str = ..., *, stream: Stream | Device | None = ...\n) -> array:\n    \"\"\"\n    The discrete convolution of 1D arrays.\n\n    If ``v`` is longer than ``a``, then they are swapped.\n    The conv filter is flipped following signal processing convention.\n\n    Args:\n        a (array): 1D Input array.\n        v (array): 1D Input array.\n        mode (str, optional): {'full', 'valid', 'same'}\n\n    Returns:\n        array: The convolved array.\n    \"\"\"\n\ndef cos(a: array, /, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Element-wise cosine.\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: The cosine of ``a``.\n    \"\"\"\n\ndef cosh(a: array, /, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Element-wise hyperbolic cosine.\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: The hyperbolic cosine of ``a``.\n    \"\"\"\n\ncpu: DeviceType = ...\n\ndef cummax(\n    a: array,\n    /,\n    axis: int | None = ...,\n    *,\n    reverse: bool = ...,\n    inclusive: bool = ...,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Return the cumulative maximum of the elements along the given axis.\n\n    Args:\n      a (array): Input array.\n      axis (int, optional): Optional axis to compute the cumulative maximum\n        over. If unspecified the cumulative maximum of the flattened array is\n        returned.\n      reverse (bool): Perform the cumulative maximum in reverse.\n      inclusive (bool): The i-th element of the output includes the i-th\n        element of the input.\n\n    Returns:\n      array: The output array.\n    \"\"\"\n\ndef cummin(\n    a: array,\n    /,\n    axis: int | None = ...,\n    *,\n    reverse: bool = ...,\n    inclusive: bool = ...,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Return the cumulative minimum of the elements along the given axis.\n\n    Args:\n      a (array): Input array.\n      axis (int, optional): Optional axis to compute the cumulative minimum\n        over. If unspecified the cumulative minimum of the flattened array is\n        returned.\n      reverse (bool): Perform the cumulative minimum in reverse.\n      inclusive (bool): The i-th element of the output includes the i-th\n        element of the input.\n\n    Returns:\n      array: The output array.\n
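\n    Example:\n\n      A small sketch of the ``reverse`` flag (expected outputs shown as\n      comments):\n\n      .. code-block:: python\n\n        import mlx.core as mx\n\n        a = mx.array([3, 1, 2])\n        mx.cummin(a)                # array([3, 1, 1], dtype=int32)\n        mx.cummin(a, reverse=True)  # array([1, 1, 2], dtype=int32)\n    \"\"\"\n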
\ndef cumprod(\n    a: array,\n    /,\n    axis: int | None = ...,\n    *,\n    reverse: bool = ...,\n    inclusive: bool = ...,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Return the cumulative product of the elements along the given axis.\n\n    Args:\n      a (array): Input array.\n      axis (int, optional): Optional axis to compute the cumulative product\n        over. If unspecified the cumulative product of the flattened array is\n        returned.\n      reverse (bool): Perform the cumulative product in reverse.\n      inclusive (bool): The i-th element of the output includes the i-th\n        element of the input.\n\n    Returns:\n      array: The output array.\n    \"\"\"\n\ndef cumsum(\n    a: array,\n    /,\n    axis: int | None = ...,\n    *,\n    reverse: bool = ...,\n    inclusive: bool = ...,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Return the cumulative sum of the elements along the given axis.\n\n    Args:\n      a (array): Input array.\n      axis (int, optional): Optional axis to compute the cumulative sum\n        over. If unspecified the cumulative sum of the flattened array is\n        returned.\n      reverse (bool): Perform the cumulative sum in reverse.\n      inclusive (bool): The i-th element of the output includes the i-th\n        element of the input.\n\n    Returns:\n      array: The output array.\n    \"\"\"\n\nclass custom_function:\n    \"\"\"\n    Set up a function for custom gradient and vmap definitions.\n\n    This class is meant to be used as a function decorator. Instances are\n    callables that behave identically to the wrapped function. However, when\n    a function transformation is used (e.g. computing gradients using\n    :func:`value_and_grad`) then the functions defined via\n    :meth:`custom_function.vjp`, :meth:`custom_function.jvp` and\n    :meth:`custom_function.vmap` are used instead of the default transformation.\n\n    Note, all custom transformations are optional. Undefined transformations\n    fall back to the default behaviour.\n\n    Example:\n\n      .. code-block:: python\n\n          import mlx.core as mx\n\n          @mx.custom_function\n          def f(x, y):\n              return mx.sin(x) * y\n\n          @f.vjp\n          def f_vjp(primals, cotangent, output):\n              x, y = primals\n              return cotangent * mx.cos(x) * y, cotangent * mx.sin(x)\n\n          @f.jvp\n          def f_jvp(primals, tangents):\n            x, y = primals\n            dx, dy = tangents\n            return dx * mx.cos(x) * y + dy * mx.sin(x)\n\n          @f.vmap\n          def f_vmap(inputs, axes):\n            x, y = inputs\n            ax, ay = axes\n            if ay != ax and ax is not None:\n                y = y.swapaxes(ay, ax)\n            return mx.sin(x) * y, (ax or ay)\n\n    All ``custom_function`` instances behave as pure functions. Namely, any\n    variables captured will be treated as constants and no gradients will be\n    computed with respect to the captured arrays. For instance:\n\n      .. 
code-block:: python\n\n        import mlx.core as mx\n\n        def g(x, y):\n          @mx.custom_function\n          def f(x):\n            return x * y\n\n          @f.vjp\n          def f_vjp(x, dx, fx):\n            # Note that we have only x, dx and fx and nothing with respect to y\n            raise ValueError(\"Abort!\")\n\n          return f(x)\n\n        x = mx.array(2.0)\n        y = mx.array(3.0)\n        print(g(x, y))                     # prints 6.0\n        print(mx.grad(g)(x, y))            # Raises exception\n        print(mx.grad(g, argnums=1)(x, y)) # prints 0.0\n    \"\"\"\n    def __init__(self, f: Callable) -> None: ...\n    def __call__(self, *args, **kwargs) -> object: ...\n    def vjp(self, f: Callable):\n        \"\"\"\n        Define a custom vjp for the wrapped function.\n\n        The vjp function takes three arguments:\n\n        - *primals*: A pytree that contains all the positional arguments to\n          the function. It could be a single array, a tuple of arrays or a\n          full blown tuple of dicts of arrays etc.\n        - *cotangents*: A pytree that matches the structure of the output\n          but contains the cotangents (usually the gradients of the loss\n          function with respect to the outputs).\n        - *outputs*: The outputs of the function to be used to avoid\n          recomputing them for the gradient computation.\n\n        The vjp function should return the same pytree structure as the\n        primals but containing the corresponding computed cotangents.\n        \"\"\"\n\n    def jvp(self, f: Callable):\n        \"\"\"\n        Define a custom jvp for the wrapped function.\n\n        The jvp function takes two arguments:\n\n        - *primals*: A pytree that contains all the positional arguments to\n          the function. It could be a single array, a tuple of arrays or a\n          full blown tuple of dicts of arrays etc.\n        - *tangents*: A pytree that matches the structure of the inputs but\n          instead contains the gradients with respect to each input. Tangents could\n          be ``None`` if some inputs don't have an associated gradient.\n\n        The jvp function should return the same pytree structure as the\n        outputs of the function but containing the tangents.\n        \"\"\"\n\n    def vmap(self, f: Callable):\n        \"\"\"\n        Define a custom vectorization transformation for the wrapped function.\n\n        The vmap function takes two arguments:\n\n        - *inputs*: A pytree that contains all the positional arguments to\n          the function. It could be a single array, a tuple of arrays or a\n          full blown tuple of dicts of arrays etc.\n        - *axes*: A pytree that matches the structure of the inputs but\n          instead contains the vectorization axis for each input or\n          ``None`` if an input is not vectorized.\n\n        The vmap function should return the outputs of the original\n        function but vectorized over the provided axes. It should also\n        return a pytree with the vectorization axes of each output. If some\n        outputs are no longer vectorized, then their vectorization axis\n        should be ``None``.\n        \"\"\"\n\ndef default_device() -> Device:\n    \"\"\"Get the default device.\"\"\"\n\ndef default_stream(device: Device | DeviceType) -> Stream:\n    \"\"\"Get the device's default stream.\"\"\"\n\ndef degrees(a: array, /, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Convert angles from radians to degrees.\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: The angles in degrees.\n
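\n    Example:\n\n      A one-line sanity check (expected output shown as a comment):\n\n      .. code-block:: python\n\n        import math\n        import mlx.core as mx\n\n        mx.degrees(mx.array([0.0, math.pi / 2, math.pi]))\n        # array([0, 90, 180], dtype=float32)\n    \"\"\"\n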
\ndef depends(\n    inputs: array | Sequence[array], dependencies: array | Sequence[array]\n) -> array | Sequence[array]:\n    \"\"\"\n    Insert dependencies between arrays in the graph. The outputs are\n    identical to ``inputs`` but with dependencies on ``dependencies``.\n\n    Args:\n        inputs (array or Sequence[array]): The input array or arrays.\n        dependencies (array or Sequence[array]): The array or arrays\n          to insert dependencies on.\n\n    Returns:\n        array or Sequence[array]: The outputs which depend on dependencies.\n    \"\"\"\n\ndef dequantize(\n    w: array,\n    /,\n    scales: array,\n    biases: array | None = ...,\n    group_size: int = ...,\n    bits: int = ...,\n    mode: str = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    r\"\"\"\n    Dequantize the matrix ``w`` using quantization parameters.\n\n    Args:\n      w (array): Matrix to be dequantized.\n      scales (array): The scales to use per ``group_size`` elements of ``w``.\n      biases (array, optional): The biases to use per ``group_size``\n         elements of ``w``. Default: ``None``.\n      group_size (int, optional): The size of the group in ``w`` that shares a\n        scale and bias. Default: ``64``.\n      bits (int, optional): The number of bits occupied by each element in\n        ``w``. Default: ``4``.\n      mode (str, optional): The quantization mode. Default: ``\"affine\"``.\n\n    Returns:\n      array: The dequantized version of ``w``.\n\n    Notes:\n      The currently supported quantization modes are ``\"affine\"`` and\n      ``\"mxfp4\"``.\n\n      For ``affine`` quantization, given the notation in :func:`quantize`,\n      we compute :math:`w_i` from :math:`\\hat{w_i}` and corresponding :math:`s`\n      and :math:`\\beta` as follows\n\n      .. math::\n\n        w_i = s \\hat{w_i} + \\beta\n    \"\"\"\n\ndef diag(a: array, /, k: int = ..., *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Extract a diagonal or construct a diagonal matrix.\n    If ``a`` is 1-D then a diagonal matrix is constructed with ``a`` on the\n    :math:`k`-th diagonal. If ``a`` is 2-D then the :math:`k`-th diagonal is\n    returned.\n\n    Args:\n        a (array): 1-D or 2-D input array.\n        k (int, optional): The diagonal to extract or construct.\n            Default: ``0``.\n\n    Returns:\n        array: The extracted diagonal or the constructed diagonal matrix.\n    \"\"\"\n\ndef diagonal(\n    a: array,\n    offset: int = ...,\n    axis1: int = ...,\n    axis2: int = ...,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Return specified diagonals.\n\n    If ``a`` is 2-D, then a 1-D array containing the diagonal at the given\n    ``offset`` is returned.\n\n    If ``a`` has more than two dimensions, then ``axis1`` and ``axis2``\n    determine the 2D subarrays from which diagonals are extracted. The new\n    shape is the original shape with ``axis1`` and ``axis2`` removed and a\n    new dimension inserted at the end corresponding to the diagonal.\n\n    Args:\n      a (array): Input array.\n      offset (int, optional): Offset of the diagonal from the main diagonal.\n        Can be positive or negative. Default: ``0``.\n      axis1 (int, optional): The first axis of the 2-D sub-arrays from which\n          the diagonals should be taken. Default: ``0``.\n      axis2 (int, optional): The second axis of the 2-D sub-arrays from which\n          the diagonals should be taken. Default: ``1``.\n\n    Returns:\n        array: The diagonals of the array.\n
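\n    Example:\n\n      A short sketch (expected outputs shown as comments):\n\n      .. code-block:: python\n\n        import mlx.core as mx\n\n        a = mx.arange(9).reshape(3, 3)\n        mx.diagonal(a)            # array([0, 4, 8], dtype=int32)\n        mx.diagonal(a, offset=1)  # array([1, 5], dtype=int32)\n    \"\"\"\n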
\ndef disable_compile() -> None:\n    \"\"\"\n    Globally disable compilation. Setting the environment variable\n    ``MLX_DISABLE_COMPILE`` can also be used to disable compilation.\n    \"\"\"\n\ndef divide(\n    a: scalar | array,\n    b: scalar | array,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Element-wise division.\n\n    Divide two arrays with numpy-style broadcasting semantics. Either or both\n    input arrays can also be scalars.\n\n    Args:\n        a (array): Input array or scalar.\n        b (array): Input array or scalar.\n\n    Returns:\n        array: The quotient ``a / b``.\n    \"\"\"\n\ndef divmod(\n    a: scalar | array,\n    b: scalar | array,\n    stream: Stream | Device | None = ...,\n) -> tuple[array, array]:\n    \"\"\"\n    Element-wise quotient and remainder.\n\n    The function ``divmod(a, b)`` is equivalent to but faster than\n    ``(a // b, a % b)``. The function uses numpy-style broadcasting\n    semantics. Either or both input arrays can also be scalars.\n\n    Args:\n        a (array): Input array or scalar.\n        b (array): Input array or scalar.\n\n    Returns:\n        tuple(array, array): The quotient ``a // b`` and remainder ``a % b``.\n    \"\"\"\n\ne: float = ...\n\ndef einsum(subscripts: str, *operands, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Perform the Einstein summation convention on the operands.\n\n    Args:\n      subscripts (str): The Einstein summation convention equation.\n      *operands (array): The input arrays.\n\n    Returns:\n      array: The output array.\n    \"\"\"\n\ndef einsum_path(subscripts: str, *operands):\n    \"\"\"\n    Compute the contraction order for the given Einstein summation.\n\n    Args:\n      subscripts (str): The Einstein summation convention equation.\n      *operands (array): The input arrays.\n\n    Returns:\n      tuple(list(tuple(int, int)), str):\n        The einsum path and a string containing information about the\n        chosen path.\n    \"\"\"\n\ndef enable_compile() -> None:\n    \"\"\"\n    Globally enable compilation. This will override the environment\n    variable ``MLX_DISABLE_COMPILE`` if set.\n    \"\"\"\n\ndef equal(\n    a: scalar | array,\n    b: scalar | array,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Element-wise equality.\n\n    Equality comparison on two arrays with numpy-style broadcasting semantics.\n    Either or both input arrays can also be scalars.\n\n    Args:\n        a (array): Input array or scalar.\n        b (array): Input array or scalar.\n\n    Returns:\n        array: The element-wise comparison ``a == b``.\n    \"\"\"\n\ndef erf(a: array, /, *, stream: Stream | Device | None = ...) -> array:\n    r\"\"\"\n    Element-wise error function.\n\n    .. math::\n      \\mathrm{erf}(x) = \\frac{2}{\\sqrt{\\pi}} \\int_0^x e^{-t^2} \\, dt\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: The error function of ``a``.\n    \"\"\"\n\ndef erfinv(a: array, /, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Element-wise inverse of :func:`erf`.\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: The inverse error function of ``a``.\n
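\n    Example:\n\n      A round-trip sketch: ``erfinv`` undoes ``erf`` up to floating-point\n      error:\n\n      .. code-block:: python\n\n        import mlx.core as mx\n\n        x = mx.array([-0.5, 0.0, 0.5])\n        mx.erfinv(mx.erf(x))  # approximately array([-0.5, 0, 0.5], dtype=float32)\n    \"\"\"\n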
\neuler_gamma: float = ...\n\ntype MX_ARRAY_TREE = (\n    array\n    | Module\n    | list[MX_ARRAY_TREE]\n    | tuple[MX_ARRAY_TREE, ...]\n    | Mapping[str, MX_ARRAY_TREE]\n)\n\ndef eval(*args: MX_ARRAY_TREE | None) -> None:\n    \"\"\"\n    Evaluate an :class:`array` or tree of :class:`array`.\n\n    Args:\n        *args (arrays or trees of arrays): Each argument can be a single array\n          or a tree of arrays. If a tree is given, the nodes can be a Python\n          :class:`list`, :class:`tuple` or :class:`dict`. Leaves which are not\n          arrays are ignored.\n    \"\"\"\n\ndef exp(a: array, /, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Element-wise exponential.\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: The exponential of ``a``.\n    \"\"\"\n\ndef expand_dims(\n    a: array,\n    /,\n    axis: int | Sequence[int],\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Add a size one dimension at the given axis.\n\n    Args:\n        a (array): Input array.\n        axis (int or Sequence[int]): The index of the inserted dimensions.\n\n    Returns:\n        array: The array with inserted dimensions.\n    \"\"\"\n\ndef expm1(a: array, /, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Element-wise exponential minus 1.\n\n    Computes ``exp(x) - 1`` with greater precision for small ``x``.\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: The expm1 of ``a``.\n    \"\"\"\n\ndef export_function(\n    file: str | Callable, fun: Callable, *args, shapeless: bool = ..., **kwargs\n) -> None:\n    \"\"\"\n    Export an MLX function.\n\n    Example input arrays must be provided to export a function. The example\n    inputs can be variable ``*args`` and ``**kwargs`` or a tuple of arrays\n    and/or dictionary of string keys with array values.\n\n    .. warning::\n\n      This is part of an experimental API which is likely to\n      change in future versions of MLX. Functions exported with older\n      versions of MLX may not be compatible with future versions.\n\n    Args:\n        file (str or Callable): Either a file path to export the function\n          to or a callback.\n        fun (Callable): A function which takes as input zero or more\n          :class:`array` and returns one or more :class:`array`.\n        *args (array): Example array inputs to the function.\n        shapeless (bool, optional): Whether or not the function allows\n          inputs with variable shapes. Default: ``False``.\n        **kwargs (array): Additional example keyword array inputs to the\n          function.\n\n    Example:\n\n      .. 
code-block:: python\n\n        def fun(x, y):\n            return x + y\n\n        x = mx.array(1)\n        y = mx.array([1, 2, 3])\n        mx.export_function(\"fun.mlxfn\", fun, x, y=y)\n    \"\"\"\n\ndef export_to_dot(file: object, *args, **kwargs) -> None:\n    \"\"\"\n    Export a graph to DOT format for visualization.\n\n    A variable number of output arrays can be provided for exporting.\n    The exported graph will recursively include all unevaluated inputs of\n    the provided outputs.\n\n    Args:\n        file (str): The file path to export to.\n        *args (array): The output arrays.\n        **kwargs (dict[str, array]): Provide some names for arrays in the\n          graph to make the result easier to parse.\n\n    Example:\n      >>> a = mx.array(1) + mx.array(2)\n      >>> mx.export_to_dot(\"graph.dot\", a)\n      >>> x = mx.array(1)\n      >>> y = mx.array(2)\n      >>> mx.export_to_dot(\"graph.dot\", x + y, x=x, y=y)\n    \"\"\"\n\ndef exporter(file: str, fun: Callable, *, shapeless: bool = ...) -> FunctionExporter:\n    \"\"\"\n    Make a callable object to export multiple traces of a function to a file.\n\n    .. warning::\n\n      This is part of an experimental API which is likely to\n      change in future versions of MLX. Functions exported with older\n      versions of MLX may not be compatible with future versions.\n\n    Args:\n        file (str): File path to export the function to.\n        fun (Callable): The function to export.\n        shapeless (bool, optional): Whether or not the function allows\n          inputs with variable shapes. Default: ``False``.\n\n    Example:\n\n      .. code-block:: python\n\n        def fun(*args):\n            return sum(args)\n\n        with mx.exporter(\"fun.mlxfn\", fun) as exporter:\n            exporter(mx.array(1))\n            exporter(mx.array(1), mx.array(2))\n            exporter(mx.array(1), mx.array(2), mx.array(3))\n    \"\"\"\n\ndef eye(\n    n: int,\n    m: int | None = ...,\n    k: int = ...,\n    dtype: Dtype | None = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Create an identity matrix or a general diagonal matrix.\n\n    Args:\n        n (int): The number of rows in the output.\n        m (int, optional): The number of columns in the output. Defaults to n.\n        k (int, optional): Index of the diagonal. Defaults to 0 (main diagonal).\n        dtype (Dtype, optional): Data type of the output array. Defaults to float32.\n        stream (Stream, optional): Stream or device. 
Defaults to None.\n\n    Returns:\n        array: An array where all elements are equal to zero, except for the k-th diagonal, whose values are equal to one.\n    \"\"\"\n\nclass finfo:\n    \"\"\"Get information on floating-point types.\"\"\"\n    def __init__(self, arg: Dtype, /) -> None: ...\n    @property\n    def min(self) -> float:\n        \"\"\"The smallest representable number.\"\"\"\n\n    @property\n    def max(self) -> float:\n        \"\"\"The largest representable number.\"\"\"\n\n    @property\n    def eps(self) -> float:\n        \"\"\"\n        The difference between 1.0 and the next smallest\n        representable number larger than 1.0.\n        \"\"\"\n\n    @property\n    def dtype(self) -> Dtype:\n        \"\"\"The :obj:`Dtype`.\"\"\"\n\n    def __repr__(self) -> str: ...\n\ndef flatten(\n    a: array,\n    /,\n    start_axis: int = ...,\n    end_axis: int = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Flatten an array.\n\n    The axes flattened will be between ``start_axis`` and ``end_axis``,\n    inclusive. Negative axes are supported. After converting negative axis to\n    positive, axes outside the valid range will be clamped to a valid value,\n    ``start_axis`` to ``0`` and ``end_axis`` to ``ndim - 1``.\n\n    Args:\n        a (array): Input array.\n        start_axis (int, optional): The first dimension to flatten. Defaults to ``0``.\n        end_axis (int, optional): The last dimension to flatten. Defaults to ``-1``.\n        stream (Stream, optional): Stream or device. Defaults to ``None``\n          in which case the default stream of the default device is used.\n\n    Returns:\n        array: The flattened array.\n\n    Example:\n        >>> a = mx.array([[1, 2], [3, 4]])\n        >>> mx.flatten(a)\n        array([1, 2, 3, 4], dtype=int32)\n        >>>\n        >>> mx.flatten(a, start_axis=0, end_axis=-1)\n        array([1, 2, 3, 4], dtype=int32)\n    \"\"\"\n\nfloat16: Dtype = ...\nfloat32: Dtype = ...\nfloat64: Dtype = ...\nfloating: DtypeCategory = ...\n\ndef floor(a: array, /, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Element-wise floor.\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: The floor of ``a``.\n    \"\"\"\n\ndef floor_divide(\n    a: scalar | array,\n    b: scalar | array,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Element-wise integer division.\n\n    If either array is a floating point type then it is equivalent to\n    calling :func:`floor` after :func:`divide`.\n\n    Args:\n        a (array): Input array or scalar.\n        b (array): Input array or scalar.\n\n    Returns:\n        array: The quotient ``a // b``.\n    \"\"\"\n\ndef full(\n    shape: int | Sequence[int],\n    vals: scalar | array,\n    dtype: Dtype | None = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Construct an array with the given value.\n\n    Constructs an array of size ``shape`` filled with ``vals``. If ``vals``\n    is an :obj:`array` it must be broadcastable to the given ``shape``.\n\n    Args:\n        shape (int or list(int)): The shape of the output array.\n        vals (float or int or array): Values to fill the array with.\n        dtype (Dtype, optional): Data type of the output array. 
If\n          unspecified, the output type is inferred from ``vals``.\n\n    Returns:\n        array: The output array with the specified shape and values.\n    \"\"\"\n\ndef gather_mm(\n    a: array,\n    b: array,\n    /,\n    lhs_indices: array | None = ...,\n    rhs_indices: array | None = ...,\n    *,\n    sorted_indices: bool = ...,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Matrix multiplication with matrix-level gather.\n\n    Performs a gather of the operands with the given indices followed by a\n    (possibly batched) matrix multiplication of two arrays.  This operation\n    is more efficient than explicitly applying a :func:`take` followed by a\n    :func:`matmul`.\n\n    The indices ``lhs_indices`` and ``rhs_indices`` contain flat indices\n    along the batch dimensions (i.e. all but the last two dimensions) of\n    ``a`` and ``b`` respectively.\n\n    For ``a`` with shape ``(A1, A2, ..., AS, M, K)``, ``lhs_indices``\n    contains indices from the range ``[0, A1 * A2 * ... * AS)``.\n\n    For ``b`` with shape ``(B1, B2, ..., BS, K, N)``, ``rhs_indices``\n    contains indices from the range ``[0, B1 * B2 * ... * BS)``.\n\n    If only one index is passed and it is sorted, the ``sorted_indices``\n    flag can be passed for a possibly faster implementation.\n\n    Args:\n        a (array): Input array.\n        b (array): Input array.\n        lhs_indices (array, optional): Integer indices for ``a``. Default: ``None``.\n        rhs_indices (array, optional): Integer indices for ``b``. Default: ``None``.\n        sorted_indices (bool, optional): May allow a faster implementation\n          if the passed indices are sorted. Default: ``False``.\n\n    Returns:\n        array: The output array.\n    \"\"\"\n\ndef gather_qmm(\n    x: array,\n    w: array,\n    /,\n    scales: array,\n    biases: array | None = ...,\n    lhs_indices: array | None = ...,\n    rhs_indices: array | None = ...,\n    transpose: bool = ...,\n    group_size: int = ...,\n    bits: int = ...,\n    mode: str = ...,\n    *,\n    sorted_indices: bool = ...,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Perform quantized matrix multiplication with matrix-level gather.\n\n    This operation is the quantized equivalent to :func:`gather_mm`.\n    Similar to :func:`gather_mm`, the indices ``lhs_indices`` and\n    ``rhs_indices`` contain flat indices along the batch dimensions (i.e.\n    all but the last two dimensions) of ``x`` and ``w`` respectively.\n\n    Note that ``scales`` and ``biases`` must have the same batch dimensions\n    as ``w`` since they represent the same quantized matrix.\n\n    Args:\n        x (array): Input array.\n        w (array): Quantized matrix packed in unsigned integers.\n        scales (array): The scales to use per ``group_size`` elements of ``w``.\n        biases (array, optional): The biases to use per ``group_size``\n          elements of ``w``. Default: ``None``.\n        lhs_indices (array, optional): Integer indices for ``x``. Default: ``None``.\n        rhs_indices (array, optional): Integer indices for ``w``. Default: ``None``.\n        transpose (bool, optional): Defines whether to multiply with the\n          transposed ``w`` or not, namely whether we are performing\n          ``x @ w.T`` or ``x @ w``. Default: ``True``.\n        group_size (int, optional): The size of the group in ``w`` that\n          shares a scale and bias. Default: ``64``.\n        bits (int, optional): The number of bits occupied by each element in\n          ``w``. 
Default: ``4``.\n        mode (str, optional): The quantization mode. Default: ``\"affine\"``.\n        sorted_indices (bool, optional): May allow a faster implementation\n          if the passed indices are sorted. Default: ``False``.\n\n    Returns:\n        array: The result of the multiplication of ``x`` with ``w``\n          after gathering using ``lhs_indices`` and ``rhs_indices``.\n    \"\"\"\n\ngeneric: DtypeCategory = ...\n\ndef get_active_memory() -> int:\n    \"\"\"\n    Get the actively used memory in bytes.\n\n    Note, this will not always match memory use reported by the system because\n    it does not include cached memory buffers.\n    \"\"\"\n\ndef get_cache_memory() -> int:\n    \"\"\"\n    Get the cache size in bytes.\n\n    The cache includes memory not currently used that has not been returned\n    to the system allocator.\n    \"\"\"\n\ndef get_peak_memory() -> int:\n    \"\"\"\n    Get the peak amount of used memory in bytes.\n\n    The maximum memory used, recorded from the beginning of the program\n    execution or since the last call to :func:`reset_peak_memory`.\n    \"\"\"\n\ngpu: DeviceType = ...\n\ndef grad(\n    fun: Callable,\n    argnums: int | Sequence[int] | None = ...,\n    argnames: str | Sequence[str] = ...,\n) -> Callable:\n    \"\"\"\n    Returns a function which computes the gradient of ``fun``.\n\n    Args:\n        fun (Callable): A function which takes a variable number of\n          :class:`array` or trees of :class:`array` and returns\n          a scalar output :class:`array`.\n        argnums (int or list(int), optional): Specify the index (or indices)\n          of the positional arguments of ``fun`` to compute the gradient\n          with respect to. If neither ``argnums`` nor ``argnames`` are\n          provided, ``argnums`` defaults to ``0`` indicating ``fun``'s first\n          argument.\n        argnames (str or list(str), optional): Specify keyword arguments of\n          ``fun`` to compute gradients with respect to. Defaults to ``[]``,\n          so no gradients are computed for keyword arguments by default.\n\n    Returns:\n        Callable: A function which has the same input arguments as ``fun`` and\n        returns the gradient(s).\n    \"\"\"\n\ndef greater(\n    a: scalar | array,\n    b: scalar | array,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Element-wise greater than.\n\n    Strict greater than on two arrays with numpy-style broadcasting semantics.\n    Either or both input arrays can also be scalars.\n\n    Args:\n        a (array): Input array or scalar.\n        b (array): Input array or scalar.\n\n    Returns:\n        array: The element-wise comparison ``a > b``.\n    \"\"\"\n\ndef greater_equal(\n    a: scalar | array,\n    b: scalar | array,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Element-wise greater or equal.\n\n    Greater than or equal on two arrays with numpy-style broadcasting semantics.\n    Either or both input arrays can also be scalars.\n\n    Args:\n        a (array): Input array or scalar.\n        b (array): Input array or scalar.\n\n    Returns:\n        array: The element-wise comparison ``a >= b``.\n    \"\"\"\n\ndef hadamard_transform(\n    a: array, scale: float | None = ..., stream: Stream | Device | None = ...\n) -> array:\n    \"\"\"\n    Perform the Walsh-Hadamard transform along the final axis.\n\n    Equivalent to:\n\n    .. 
code-block:: python\n\n       from scipy.linalg import hadamard\n\n       y = (hadamard(len(x)) @ x) * scale\n\n    Supports sizes ``n = m*2^k`` for ``m`` in ``(1, 12, 20, 28)`` and ``2^k\n    <= 8192`` for float32 and ``2^k <= 16384`` for float16/bfloat16.\n\n    Args:\n        a (array): Input array or scalar.\n        scale (float): Scale the output by this factor.\n          Defaults to ``1/sqrt(a.shape[-1])`` so that the Hadamard matrix is orthonormal.\n\n    Returns:\n        array: The transformed array.\n    \"\"\"\n\ndef identity(\n    n: int, dtype: Dtype | None = ..., *, stream: Stream | Device | None = ...\n) -> array:\n    \"\"\"\n    Create a square identity matrix.\n\n    Args:\n        n (int): The number of rows and columns in the output.\n        dtype (Dtype, optional): Data type of the output array. Defaults to float32.\n        stream (Stream, optional): Stream or device. Defaults to None.\n\n    Returns:\n        array: An identity matrix of size n x n.\n    \"\"\"\n\nclass iinfo:\n    \"\"\"Get information on integer types.\"\"\"\n    def __init__(self, arg: Dtype, /) -> None: ...\n    @property\n    def min(self) -> int:\n        \"\"\"The smallest representable number.\"\"\"\n\n    @property\n    def max(self) -> int:\n        \"\"\"The largest representable number.\"\"\"\n\n    @property\n    def dtype(self) -> Dtype:\n        \"\"\"The :obj:`Dtype`.\"\"\"\n\n    def __repr__(self) -> str: ...\n\ndef imag(a: array, /, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Returns the imaginary part of a complex array.\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: The imaginary part of ``a``.\n    \"\"\"\n\ndef import_function(file: str) -> Callable:\n    \"\"\"\n    Import a function from a file.\n\n    The imported function can be called either with ``*args`` and\n    ``**kwargs`` or with a tuple of arrays and/or dictionary of string\n    keys with array values. Imported functions always return a tuple of\n    arrays.\n\n    .. warning::\n\n      This is part of an experimental API which is likely to\n      change in future versions of MLX. Functions exported with older\n      versions of MLX may not be compatible with future versions.\n\n    Args:\n        file (str): The file path to import the function from.\n\n    Returns:\n        Callable: The imported function.\n\n    Example:\n      >>> fn = mx.import_function(\"function.mlxfn\")\n      >>> out = fn(a, b, x=x, y=y)[0]\n      >>>\n      >>> out = fn((a, b), {\"x\": x, \"y\": y})[0]\n    \"\"\"\n\ninexact: DtypeCategory = ...\ninf: float = ...\n\ndef inner(a: array, b: array, /, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Ordinary inner product of vectors for 1-D arrays; in higher dimensions,\n    a sum product over the last axes.\n\n    Args:\n      a (array): Input array.\n      b (array): Input array.\n\n    Returns:\n      array: The inner product.\n
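\n    Example:\n\n      A 1-D sketch, where ``inner`` reduces to the ordinary dot product\n      (expected output shown as a comment):\n\n      .. code-block:: python\n\n        import mlx.core as mx\n\n        a = mx.array([1.0, 2.0, 3.0])\n        b = mx.array([4.0, 5.0, 6.0])\n        mx.inner(a, b)  # array(32, dtype=float32)\n    \"\"\"\n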
\nint16: Dtype = ...\nint32: Dtype = ...\nint64: Dtype = ...\nint8: Dtype = ...\ninteger: DtypeCategory = ...\n\ndef is_available(device: Device) -> bool:\n    \"\"\"Check if a back-end is available for the given device.\"\"\"\n\ndef isclose(\n    a: array,\n    b: array,\n    /,\n    rtol: float = ...,\n    atol: float = ...,\n    *,\n    equal_nan: bool = ...,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Returns a boolean array where two arrays are element-wise equal within a tolerance.\n\n    Infinite values are considered equal if they have the same sign, NaN values are\n    not equal unless ``equal_nan`` is ``True``.\n\n    Two values are considered equal if:\n\n    .. code-block::\n\n     abs(a - b) <= (atol + rtol * abs(b))\n\n    Note that unlike :func:`array_equal`, this function supports numpy-style\n    broadcasting.\n\n    Args:\n        a (array): Input array.\n        b (array): Input array.\n        rtol (float): Relative tolerance.\n        atol (float): Absolute tolerance.\n        equal_nan (bool): If ``True``, NaNs are considered equal.\n          Defaults to ``False``.\n\n    Returns:\n        array: The boolean array indicating where ``a`` and ``b`` are\n        element-wise close.\n    \"\"\"\n\ndef isfinite(a: array, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Return a boolean array indicating which elements are finite.\n\n    An element is finite if it is not infinite or NaN.\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: The boolean array indicating which elements are finite.\n    \"\"\"\n\ndef isinf(a: array, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Return a boolean array indicating which elements are +/- infinity.\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: The boolean array indicating which elements are +/- infinity.\n    \"\"\"\n\ndef isnan(a: array, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Return a boolean array indicating which elements are NaN.\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: The boolean array indicating which elements are NaN.\n    \"\"\"\n\ndef isneginf(a: array, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Return a boolean array indicating which elements are negative infinity.\n\n    Args:\n        a (array): Input array.\n        stream (Stream | Device | None): Optional stream or device.\n\n    Returns:\n        array: The boolean array indicating which elements are negative infinity.\n    \"\"\"\n\ndef isposinf(a: array, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Return a boolean array indicating which elements are positive infinity.\n\n    Args:\n        a (array): Input array.\n        stream (Stream | Device | None): Optional stream or device.\n\n    Returns:\n        array: The boolean array indicating which elements are positive infinity.\n
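\n    Example:\n\n      A quick sketch (expected output shown as a comment):\n\n      .. code-block:: python\n\n        import mlx.core as mx\n\n        a = mx.array([1.0, float(\"inf\"), float(\"-inf\")])\n        mx.isposinf(a)  # array([False, True, False], dtype=bool)\n    \"\"\"\n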
\ndef issubdtype(arg1: Dtype | DtypeCategory, arg2: Dtype | DtypeCategory) -> bool:\n    \"\"\"\n    Check if a :obj:`Dtype` or :obj:`DtypeCategory` is a subtype\n    of another.\n\n    Args:\n        arg1 (Dtype | DtypeCategory): First dtype or category.\n        arg2 (Dtype | DtypeCategory): Second dtype or category.\n\n    Returns:\n        bool:\n           A boolean indicating if the first input is a subtype of the\n           second input.\n\n    Example:\n\n      >>> ints = mx.array([1, 2, 3], dtype=mx.int32)\n      >>> mx.issubdtype(ints.dtype, mx.integer)\n      True\n      >>> mx.issubdtype(ints.dtype, mx.floating)\n      False\n\n      >>> floats = mx.array([1, 2, 3], dtype=mx.float32)\n      >>> mx.issubdtype(floats.dtype, mx.integer)\n      False\n      >>> mx.issubdtype(floats.dtype, mx.floating)\n      True\n\n      Similar types of different sizes are not subdtypes of each other:\n\n      >>> mx.issubdtype(mx.float64, mx.float32)\n      False\n      >>> mx.issubdtype(mx.float32, mx.float64)\n      False\n\n      but both are subtypes of ``floating``:\n\n      >>> mx.issubdtype(mx.float64, mx.floating)\n      True\n      >>> mx.issubdtype(mx.float32, mx.floating)\n      True\n\n      For convenience, dtype-like objects are allowed too:\n\n      >>> mx.issubdtype(mx.float32, mx.inexact)\n      True\n      >>> mx.issubdtype(mx.signedinteger, mx.floating)\n      False\n    \"\"\"\n\ndef jvp(\n    fun: Callable, primals: list[array], tangents: list[array]\n) -> tuple[list[array], list[array]]:\n    \"\"\"\n    Compute the Jacobian-vector product.\n\n    This computes the product of the Jacobian of a function ``fun`` evaluated\n    at ``primals`` with the ``tangents``.\n\n    Args:\n        fun (Callable): A function which takes a variable number of :class:`array`\n          and returns a single :class:`array` or list of :class:`array`.\n        primals (list(array)): A list of :class:`array` at which to\n          evaluate the Jacobian.\n        tangents (list(array)): A list of :class:`array` which are the\n          \"vector\" in the Jacobian-vector product. The ``tangents`` should be the\n          same in number, shape, and type as the inputs of ``fun`` (i.e. the ``primals``).\n\n    Returns:\n        list(array): A list of the Jacobian-vector products which\n        is the same in number, shape, and type as the outputs of ``fun``.\n    \"\"\"\n\ndef kron(a: array, b: array, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Compute the Kronecker product of two arrays ``a`` and ``b``.\n\n    Args:\n      a (array): The first input array.\n      b (array): The second input array.\n      stream (Stream | Device | None, optional): Optional stream or\n        device for execution. 
Default: ``None``.\n\n    Returns:\n      array: The Kronecker product of ``a`` and ``b``.\n\n    Examples:\n      >>> a = mx.array([[1, 2], [3, 4]])\n      >>> b = mx.array([[0, 5], [6, 7]])\n      >>> result = mx.kron(a, b)\n      >>> print(result)\n      array([[0, 5, 0, 10],\n             [6, 7, 12, 14],\n             [0, 15, 0, 20],\n             [18, 21, 24, 28]], dtype=int32)\n    \"\"\"\n\ndef left_shift(\n    a: scalar | array,\n    b: scalar | array,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Element-wise left shift.\n\n    Shift the bits of the first input to the left by the second using\n    numpy-style broadcasting semantics. Either or both input arrays can\n    also be scalars.\n\n    Args:\n        a (array): Input array or scalar.\n        b (array): Input array or scalar.\n\n    Returns:\n        array: The bitwise left shift ``a << b``.\n    \"\"\"\n\ndef less(\n    a: scalar | array,\n    b: scalar | array,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Element-wise less than.\n\n    Strict less than on two arrays with numpy-style broadcasting semantics.\n    Either or both input arrays can also be scalars.\n\n    Args:\n        a (array): Input array or scalar.\n        b (array): Input array or scalar.\n\n    Returns:\n        array: The element-wise comparison ``a < b``.\n    \"\"\"\n\ndef less_equal(\n    a: scalar | array,\n    b: scalar | array,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Element-wise less than or equal.\n\n    Less than or equal on two arrays with numpy-style broadcasting semantics.\n    Either or both input arrays can also be scalars.\n\n    Args:\n        a (array): Input array or scalar.\n        b (array): Input array or scalar.\n\n    Returns:\n        array: The element-wise comparison ``a <= b``.\n    \"\"\"\n\ndef linspace(\n    start: scalar,\n    stop: scalar,\n    num: int | None = ...,\n    dtype: Dtype | None = ...,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Generate ``num`` evenly spaced numbers over interval ``[start, stop]``.\n\n    Args:\n        start (scalar): Starting value.\n        stop (scalar): Stopping value.\n        num (int, optional): Number of samples, defaults to ``50``.\n        dtype (Dtype, optional): Specifies the data type of the output,\n          defaults to ``float32``.\n\n    Returns:\n        array: The range of values.\n    \"\"\"\n\ndef load(\n    file: str | pathlib.Path,\n    /,\n    format: str | None = ...,\n    return_metadata: bool = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array | dict[str, array]:\n    \"\"\"\n    Load array(s) from a binary file.\n\n    The supported formats are ``.npy``, ``.npz``, ``.safetensors``, and\n    ``.gguf``.\n\n    Args:\n        file (str, pathlib.Path): File in which the array is saved.\n        format (str, optional): Format of the file. If ``None``, the\n          format is inferred from the file extension. Supported formats:\n          ``npy``, ``npz``, ``safetensors``, and ``gguf``. Default: ``None``.\n        return_metadata (bool, optional): Load the metadata for formats\n          which support metadata. The metadata will be returned as an\n          additional dictionary. Default: ``False``.\n\n    Returns:\n        array or dict:\n            A single array if loading from a ``.npy`` file or a dict\n            mapping names to arrays if loading from a ``.npz`` or\n            ``.safetensors`` file. If ``return_metadata`` is ``True``, an\n            additional dictionary of metadata will be returned.\n\n    Warning:\n\n      When loading unsupported quantization formats from GGUF, tensors\n      are automatically cast to ``mx.float16``.\n
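\n    Example:\n\n      A round-trip sketch through ``.safetensors`` (the file name is\n      hypothetical):\n\n      .. code-block:: python\n\n        import mlx.core as mx\n\n        mx.save_safetensors(\"weights.safetensors\", {\"w\": mx.zeros((2, 2))})\n        arrays = mx.load(\"weights.safetensors\")\n        arrays[\"w\"].shape  # (2, 2)\n    \"\"\"\n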
\ndef log(a: array, /, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Element-wise natural logarithm.\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: The natural logarithm of ``a``.\n    \"\"\"\n\ndef log10(a: array, /, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Element-wise base-10 logarithm.\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: The base-10 logarithm of ``a``.\n    \"\"\"\n\ndef log1p(a: array, /, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Element-wise natural log of one plus the array.\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: The natural logarithm of one plus ``a``.\n    \"\"\"\n\ndef log2(a: array, /, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Element-wise base-2 logarithm.\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: The base-2 logarithm of ``a``.\n    \"\"\"\n\ndef logaddexp(\n    a: scalar | array,\n    b: scalar | array,\n    /,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Element-wise log-add-exp.\n\n    This is a numerically stable log-add-exp of two arrays with numpy-style\n    broadcasting semantics. Either or both input arrays can also be scalars.\n\n    The computation is a numerically stable version of ``log(exp(a) + exp(b))``.\n\n    Args:\n        a (array): Input array or scalar.\n        b (array): Input array or scalar.\n\n    Returns:\n        array: The log-add-exp of ``a`` and ``b``.\n    \"\"\"\n\ndef logcumsumexp(\n    a: array,\n    /,\n    axis: int | None = ...,\n    *,\n    reverse: bool = ...,\n    inclusive: bool = ...,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Return the cumulative logsumexp of the elements along the given axis.\n\n    Args:\n      a (array): Input array.\n      axis (int, optional): Optional axis to compute the cumulative logsumexp\n        over. If unspecified the cumulative logsumexp of the flattened array is\n        returned.\n      reverse (bool): Perform the cumulative logsumexp in reverse.\n      inclusive (bool): The i-th element of the output includes the i-th\n        element of the input.\n\n    Returns:\n      array: The output array.\n    \"\"\"\n\ndef logical_and(\n    a: array, b: array, /, *, stream: Stream | Device | None = ...\n) -> array:\n    \"\"\"\n    Element-wise logical and.\n\n    Args:\n        a (array): First input array or scalar.\n        b (array): Second input array or scalar.\n\n    Returns:\n        array: The boolean array containing the logical and of ``a`` and ``b``.\n    \"\"\"\n\ndef logical_not(a: array, /, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Element-wise logical not.\n\n    Args:\n        a (array): Input array or scalar.\n\n    Returns:\n        array: The boolean array containing the logical not of ``a``.\n    \"\"\"\n\ndef logical_or(a: array, b: array, /, *, stream: Stream | Device | None = ...) 
-> array:\n    \"\"\"\n    Element-wise logical or.\n\n    Args:\n        a (array): First input array or scalar.\n        b (array): Second input array or scalar.\n\n    Returns:\n        array: The boolean array containing the logical or of ``a`` and ``b``.\n    \"\"\"\n\ndef logsumexp(\n    a: array,\n    /,\n    axis: int | Sequence[int] | None = ...,\n    keepdims: bool = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    A `log-sum-exp` reduction over the given axes.\n\n    The log-sum-exp reduction is a numerically stable version of:\n\n    .. code-block::\n\n      log(sum(exp(a), axis))\n\n    Args:\n        a (array): Input array.\n        axis (int or list(int), optional): Optional axis or\n          axes to reduce over. If unspecified this defaults\n          to reducing over the entire array.\n        keepdims (bool, optional): Keep reduced axes as\n          singleton dimensions, defaults to `False`.\n\n    Returns:\n        array: The output array with the corresponding axes reduced.\n    \"\"\"\n\ndef matmul(a: array, b: array, /, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Matrix multiplication.\n\n    Perform the (possibly batched) matrix multiplication of two arrays. This function supports\n    broadcasting for arrays with more than two dimensions.\n\n    - If the first array is 1-D then a 1 is prepended to its shape to make it\n      a matrix. Similarly if the second array is 1-D then a 1 is appended to its\n      shape to make it a matrix. In either case the singleton dimension is removed\n      from the result.\n    - A batched matrix multiplication is performed if the arrays have more than\n      2 dimensions.  The matrix dimensions for the matrix product are the last\n      two dimensions of each input.\n    - All but the last two dimensions of each input are broadcast with one another using\n      standard numpy-style broadcasting semantics.\n\n    Args:\n        a (array): Input array or scalar.\n        b (array): Input array or scalar.\n\n    Returns:\n        array: The matrix product of ``a`` and ``b``.\n    \"\"\"\n\ndef max(\n    a: array,\n    /,\n    axis: int | Sequence[int] | None = ...,\n    keepdims: bool = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    A `max` reduction over the given axes.\n\n    Args:\n        a (array): Input array.\n        axis (int or list(int), optional): Optional axis or\n          axes to reduce over. If unspecified this defaults\n          to reducing over the entire array.\n        keepdims (bool, optional): Keep reduced axes as\n          singleton dimensions, defaults to `False`.\n\n    Returns:\n        array: The output array with the corresponding axes reduced.\n    \"\"\"\n\ndef maximum(\n    a: scalar | array,\n    b: scalar | array,\n    /,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Element-wise maximum.\n\n    Take the element-wise max of two arrays with numpy-style broadcasting\n    semantics. 
Either or both input arrays can also be scalars.\n\n    Args:\n        a (array): Input array or scalar.\n        b (array): Input array or scalar.\n\n    Returns:\n        array: The max of ``a`` and ``b``.\n    \"\"\"\n\ndef mean(\n    a: array,\n    /,\n    axis: int | Sequence[int] | None = ...,\n    keepdims: bool = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Compute the mean(s) over the given axes.\n\n    Args:\n        a (array): Input array.\n        axis (int or list(int), optional): Optional axis or\n          axes to reduce over. If unspecified this defaults\n          to reducing over the entire array.\n        keepdims (bool, optional): Keep reduced axes as\n          singleton dimensions, defaults to `False`.\n\n    Returns:\n        array: The output array of means.\n    \"\"\"\n\ndef meshgrid(\n    *arrays: array,\n    sparse: bool | None = ...,\n    indexing: str | None = ...,\n    stream: Stream | Device | None = ...,\n) -> list[array]:\n    \"\"\"\n    Generate multidimensional coordinate grids from 1-D coordinate arrays.\n\n    Args:\n        *arrays (array): Input arrays.\n        sparse (bool, optional): If ``True``, a sparse grid is returned in which each output\n          array has a single non-zero element. If ``False``, a dense grid is returned.\n          Defaults to ``False``.\n        indexing (str, optional): Cartesian ('xy') or matrix ('ij') indexing of the output arrays.\n          Defaults to ``'xy'``.\n\n    Returns:\n        list(array): The output arrays.\n    \"\"\"\n\ndef min(\n    a: array,\n    /,\n    axis: int | Sequence[int] | None = ...,\n    keepdims: bool = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    A `min` reduction over the given axes.\n\n    Args:\n        a (array): Input array.\n        axis (int or list(int), optional): Optional axis or\n          axes to reduce over. If unspecified this defaults\n          to reducing over the entire array.\n        keepdims (bool, optional): Keep reduced axes as\n          singleton dimensions, defaults to `False`.\n\n    Returns:\n        array: The output array with the corresponding axes reduced.\n    \"\"\"\n\ndef minimum(\n    a: scalar | array,\n    b: scalar | array,\n    /,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Element-wise minimum.\n\n    Take the element-wise min of two arrays with numpy-style broadcasting\n    semantics. Either or both input arrays can also be scalars.\n\n    Args:\n        a (array): Input array or scalar.\n        b (array): Input array or scalar.\n\n    Returns:\n        array: The min of ``a`` and ``b``.\n    \"\"\"\n\ndef moveaxis(\n    a: array,\n    /,\n    source: int,\n    destination: int,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Move an axis to a new position.\n\n    Args:\n        a (array): Input array.\n        source (int): Specifies the source axis.\n        destination (int): Specifies the destination axis.\n\n    Returns:\n        array: The array with the axis moved.\n    \"\"\"\n\ndef multiply(\n    a: scalar | array,\n    b: scalar | array,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Element-wise multiplication.\n\n    Multiply two arrays with numpy-style broadcasting semantics. 
Either or both\n    input arrays can also be scalars.\n\n    Args:\n        a (array): Input array or scalar.\n        b (array): Input array or scalar.\n\n    Returns:\n        array: The multiplication ``a * b``.\n    \"\"\"\n\nnan: float = ...\n\ndef nan_to_num(\n    a: scalar | array,\n    nan: float = ...,\n    posinf: float | None = ...,\n    neginf: float | None = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Replace NaN and Inf values with finite numbers.\n\n    Args:\n        a (array): Input array\n        nan (float, optional): Value to replace NaN with. Default: ``0``.\n        posinf (float, optional): Value to replace positive infinities\n          with. If ``None``, defaults to largest finite value for the\n          given data type. Default: ``None``.\n        neginf (float, optional): Value to replace negative infinities\n          with. If ``None``, defaults to the negative of the largest\n          finite value for the given data type. Default: ``None``.\n\n    Returns:\n        array: Output array with NaN and Inf replaced.\n    \"\"\"\n\ndef negative(a: array, /, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Element-wise negation.\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: The negative of ``a``.\n    \"\"\"\n\ndef new_stream(device: Device) -> Stream:\n    \"\"\"Make a new stream on the given device.\"\"\"\n\nnewaxis: None = ...\n\ndef not_equal(\n    a: scalar | array,\n    b: scalar | array,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Element-wise not equal.\n\n    Not equal comparison on two arrays with numpy-style broadcasting semantics.\n    Either or both input arrays can also be scalars.\n\n    Args:\n        a (array): Input array or scalar.\n        b (array): Input array or scalar.\n\n    Returns:\n        array: The element-wise comparison ``a != b``.\n    \"\"\"\n\nnumber: DtypeCategory = ...\n\ndef ones(\n    shape: int | Sequence[int],\n    dtype: Dtype | None = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Construct an array of ones.\n\n    Args:\n        shape (int or list(int)): The shape of the output array.\n        dtype (Dtype, optional): Data type of the output array. If\n          unspecified the output type defaults to ``float32``.\n\n    Returns:\n        array: The array of ones with the specified shape.\n    \"\"\"\n\ndef ones_like(a: array, /, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    An array of ones like the input.\n\n    Args:\n        a (array): The input to take the shape and type from.\n\n    Returns:\n        array: The output array filled with ones.\n    \"\"\"\n\ndef outer(a: array, b: array, /, *, stream: Stream | Device | None = ...) 
-> array:\n    \"\"\"\n    Compute the outer product of two 1-D arrays. If the arrays passed are not\n    1-D they will be flattened beforehand.\n\n    Args:\n      a (array): Input array\n      b (array): Input array\n\n    Returns:\n      array: The outer product.\n    \"\"\"\n\ndef pad(\n    a: array,\n    pad_width: int | tuple[int] | tuple[int, int] | list[tuple[int, int]],\n    mode: Literal[\"constant\", \"edge\"] = ...,\n    constant_values: scalar | array = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Pad an array with a constant value.\n\n    Args:\n        a (array): Input array.\n        pad_width (int, tuple(int), tuple(int, int) or list(tuple(int, int))): Number of padded\n          values to add to the edges of each axis: ``((before_1, after_1),\n          (before_2, after_2), ..., (before_N, after_N))``. If a single pair\n          of integers is passed then ``(before_i, after_i)`` are all the same.\n          If a single integer or tuple with a single integer is passed then\n          all axes are extended by the same number on each side.\n        mode: Padding mode. One of the following strings:\n          \"constant\" (default): Pads with a constant value.\n          \"edge\": Pads with the edge values of array.\n        constant_values (array or scalar, optional): Optional constant value\n          to pad the edges of the array with.\n\n    Returns:\n        array: The padded array.\n    \"\"\"\n\ndef partition(\n    a: array,\n    /,\n    kth: int,\n    axis: int | None = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Returns a partitioned copy of the array such that the smaller ``kth``\n    elements are first.\n\n    The ordering of the elements in partitions is undefined.\n\n    Args:\n        a (array): Input array.\n        kth (int): Element at the ``kth`` index will be in its sorted\n          position in the output. All elements before the kth index will\n          be less or equal to the ``kth`` element and all elements after\n          will be greater or equal to the ``kth`` element in the output.\n        axis (int or None, optional): Optional axis to partition over.\n          If ``None``, this partitions over the flattened array.\n          If unspecified, it defaults to ``-1``.\n\n    Returns:\n        array: The partitioned array.\n    \"\"\"\n\ndef permute_dims(\n    a: array,\n    /,\n    axes: Sequence[int] | None = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"See :func:`transpose`.\"\"\"\n\npi: float = ...\n\ndef power(\n    a: scalar | array,\n    b: scalar | array,\n    /,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Element-wise power operation.\n\n    Raise the elements of a to the powers in elements of b with numpy-style\n    broadcasting semantics. Either or both input arrays can also be scalars.\n\n    Args:\n        a (array): Input array or scalar.\n        b (array): Input array or scalar.\n\n    Returns:\n        array: Bases of ``a`` raised to powers in ``b``.\n    \"\"\"\n\ndef prod(\n    a: array,\n    /,\n    axis: int | Sequence[int] | None = ...,\n    keepdims: bool = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    A product reduction over the given axes.\n\n    Args:\n        a (array): Input array.\n        axis (int or list(int), optional): Optional axis or\n          axes to reduce over. If unspecified this defaults\n          to reducing over the entire array.\n        keepdims (bool, optional): Keep reduced axes as\n          singleton dimensions, defaults to `False`.\n\n    Returns:\n        array: The output array with the corresponding axes reduced.\n    \"\"\"\n\ndef put_along_axis(\n    a: array,\n    /,\n    indices: array,\n    values: array,\n    axis: int | None = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Put values along an axis at the specified indices.\n\n    Args:\n        a (array): Destination array.\n        indices (array): Indices array. These should be broadcastable with\n          the input array excluding the `axis` dimension.\n        values (array): Values array. These should be broadcastable with\n          the indices.\n        axis (int or None): Axis in the destination to put the values to. If\n          ``axis == None`` the destination is flattened prior to the put\n          operation.\n\n    Returns:\n        array: The output array.\n    \"\"\"\n\ndef quantize(\n    w: array,\n    /,\n    group_size: int = ...,\n    bits: int = ...,\n    mode: str = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> tuple[array, array, array]:\n    r\"\"\"\n    Quantize the matrix ``w`` using ``bits`` bits per element.\n\n    Note, every ``group_size`` elements in a row of ``w`` are quantized\n    together. Hence, the number of columns of ``w`` should be divisible by\n    ``group_size``.\n\n    .. warning::\n\n      ``quantize`` currently only supports 2D inputs with the second\n      dimension divisible by ``group_size``.\n\n    The supported quantization modes are ``\"affine\"`` and ``\"mxfp4\"``. They\n    are described in more detail below.\n\n    Args:\n      w (array): Matrix to be quantized\n      group_size (int, optional): The size of the group in ``w`` that shares a\n        scale and bias. Default: ``64``.\n      bits (int, optional): The number of bits occupied by each element of\n        ``w`` in the returned quantized matrix. Default: ``4``.\n      mode (str, optional): The quantization mode. Default: ``\"affine\"``.\n\n    Returns:\n      tuple: A tuple with either two or three elements containing:\n\n      * w_q (array): The quantized version of ``w``\n      * scales (array): The quantization scales\n      * biases (array): The quantization biases (returned for ``mode==\"affine\"``).\n\n    Notes:\n      The ``affine`` mode quantizes groups of :math:`g` consecutive\n      elements in a row of ``w``. For each group the quantized\n      representation of each element :math:`\\hat{w_i}` is computed as follows:\n\n      .. math::\n\n        \\begin{aligned}\n          \\alpha &= \\max_i w_i \\\\\n          \\beta &= \\min_i w_i \\\\\n          s &= \\frac{\\alpha - \\beta}{2^b - 1} \\\\\n          \\hat{w_i} &= \\textrm{round}\\left( \\frac{w_i - \\beta}{s}\\right).\n        \\end{aligned}\n\n      After the above computation, :math:`\\hat{w_i}` fits in :math:`b` bits\n      and is packed in an unsigned 32-bit integer from the lower to upper\n      bits. For instance, for 4-bit quantization we fit 8 elements in an\n      unsigned 32-bit integer where the 1st element occupies the 4 least\n      significant bits, the 2nd occupies bits 4-7, etc.\n\n      To dequantize the elements of ``w``, we also save :math:`s` and\n      :math:`\\beta` which are the returned ``scales`` and\n      ``biases`` respectively.\n\n      The ``mxfp4`` mode similarly quantizes groups of :math:`g` elements\n      of ``w``. For ``mxfp4`` the group size must be ``32``. The elements\n      are quantized to 4-bit precision floating-point values (E2M1) with a\n      shared 8-bit scale per group. Unlike ``affine`` quantization,\n      ``mxfp4`` does not have a bias value. More details on the format can\n      be found in the `specification <https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf>`_.\n    \"\"\"\n\ndef quantized_matmul(\n    x: array,\n    w: array,\n    /,\n    scales: array,\n    biases: array | None = ...,\n    transpose: bool = ...,\n    group_size: int = ...,\n    bits: int = ...,\n    mode: str = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Perform the matrix multiplication with the quantized matrix ``w``. The\n    quantization uses one floating point scale and bias per ``group_size`` of\n    elements. Each element in ``w`` takes ``bits`` bits and is packed in an\n    unsigned 32-bit integer.\n\n    Args:\n      x (array): Input array\n      w (array): Quantized matrix packed in unsigned integers\n      scales (array): The scales to use per ``group_size`` elements of ``w``\n      biases (array, optional): The biases to use per ``group_size``\n        elements of ``w``. Default: ``None``.\n      transpose (bool, optional): Defines whether to multiply with the\n        transposed ``w`` or not, namely whether we are performing\n        ``x @ w.T`` or ``x @ w``. Default: ``True``.\n      group_size (int, optional): The size of the group in ``w`` that\n        shares a scale and bias. Default: ``64``.\n      bits (int, optional): The number of bits occupied by each element in\n        ``w``. Default: ``4``.\n      mode (str, optional): The quantization mode. Default: ``\"affine\"``.\n\n    Returns:\n      array: The result of the multiplication of ``x`` with ``w``.\n    \"\"\"\n\ndef radians(a: array, /, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Convert angles from degrees to radians.\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: The angles in radians.\n    \"\"\"\n\ndef real(a: array, /, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Returns the real part of a complex array.\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: The real part of ``a``.\n    \"\"\"\n\ndef reciprocal(a: array, /, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Element-wise reciprocal.\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: The reciprocal of ``a``.\n    \"\"\"\n\ndef remainder(\n    a: scalar | array,\n    b: scalar | array,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Element-wise remainder of division.\n\n    Computes the remainder of dividing ``a`` by ``b`` with numpy-style\n    broadcasting semantics. Either or both input arrays can also be\n    scalars.\n\n    Args:\n        a (array): Input array or scalar.\n        b (array): Input array or scalar.\n\n    Returns:\n        array: The remainder of ``a // b``.\n    \"\"\"\n\ndef repeat(\n    array: array,\n    repeats: int,\n    axis: int | None = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Repeat an array along a specified axis.\n\n    Args:\n        array (array): Input array.\n        repeats (int): The number of repetitions for each element.\n        axis (int, optional): The axis along which to repeat the array. If\n          unspecified it uses the flattened array of the input and repeats\n          along axis 0.\n        stream (Stream, optional): Stream or device. Defaults to ``None``.\n\n    Returns:\n        array: The resulting repeated array.\n    \"\"\"\n\ndef reset_peak_memory() -> None:\n    \"\"\"Reset the peak memory to zero.\"\"\"\n\ndef reshape(\n    a: array, /, shape: Sequence[int], *, stream: Stream | Device | None = ...\n) -> array:\n    \"\"\"\n    Reshape an array while preserving the size.\n\n    Args:\n        a (array): Input array.\n        shape (tuple(int)): New shape.\n        stream (Stream, optional): Stream or device. Defaults to ``None``\n          in which case the default stream of the default device is used.\n\n    Returns:\n        array: The reshaped array.\n    \"\"\"\n\ndef right_shift(\n    a: scalar | array,\n    b: scalar | array,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Element-wise right shift.\n\n    Shift the bits of the first input to the right by the second using\n    numpy-style broadcasting semantics. Either or both input arrays can\n    also be scalars.\n\n    Args:\n        a (array): Input array or scalar.\n        b (array): Input array or scalar.\n\n    Returns:\n        array: The bitwise right shift ``a >> b``.\n    \"\"\"\n\ndef roll(\n    a: array,\n    shift: int | tuple[int],\n    axis: int | tuple[int] | None = ...,\n    /,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Roll array elements along a given axis.\n\n    Elements that are rolled beyond the end of the array are introduced at\n    the beginning and vice-versa.\n\n    If the axis is not provided the array is flattened, rolled and then the\n    shape is restored.\n\n    Args:\n      a (array): Input array\n      shift (int or tuple(int)): The number of places by which elements\n        are shifted. If positive the array is rolled to the right, if\n        negative it is rolled to the left. If an int is provided but the\n        axis is a tuple then the same value is used for all axes.\n      axis (int or tuple(int), optional): The axis or axes along which to\n        roll the elements.\n\n    Returns:\n      array: The rolled array.\n    \"\"\"\n\ndef round(\n    a: array, /, decimals: int = ..., stream: Stream | Device | None = ...\n) -> array:\n    \"\"\"\n    Round to the given number of decimals.\n\n    Basically performs:\n\n    .. code-block:: python\n\n      s = 10**decimals\n      x = round(x * s) / s\n\n    Args:\n      a (array): Input array\n      decimals (int): Number of decimal places to round to. (default: 0)\n\n    Returns:\n      array: An array of the same type as ``a`` rounded to the\n      given number of decimals.\n    \"\"\"\n\ndef rsqrt(a: array, /, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Element-wise reciprocal square root.\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: One over the square root of ``a``.\n    \"\"\"\n\ndef save(file: str | pathlib.Path, arr: array) -> None:\n    \"\"\"\n    Save the array to a binary file in ``.npy`` format.\n\n    Args:\n        file (str, pathlib.Path): File to which the array is saved\n        arr (array): Array to be saved.\n    \"\"\"\n\ndef save_gguf(\n    file: str | pathlib.Path,\n    arrays: dict[str, array],\n    metadata: dict[str, array | str | list[str]],\n):\n    \"\"\"\n    Save array(s) to a binary file in ``.gguf`` format.\n\n    See the `GGUF documentation\n    <https://github.com/ggerganov/ggml/blob/master/docs/gguf.md>`_ for\n    more information on the format.\n\n    Args:\n        file (file, str, pathlib.Path): File in which the array is saved.\n        arrays (dict(str, array)): The dictionary of names to arrays to\n          be saved.\n        metadata (dict(str, array | str | list(str))): The dictionary\n          of metadata to be saved. The values can be a scalar or 1D\n          :obj:`array`, a :obj:`str`, or a :obj:`list` of :obj:`str`.\n    \"\"\"\n\ndef save_safetensors(\n    file: str | pathlib.Path,\n    arrays: dict[str, array],\n    metadata: dict[str, str] | None = ...,\n):\n    \"\"\"\n    Save array(s) to a binary file in ``.safetensors`` format.\n\n    See the `Safetensors documentation\n    <https://huggingface.co/docs/safetensors/index>`_ for more\n    information on the format.\n\n    Args:\n        file (file, str, pathlib.Path): File in which the array is saved.\n        arrays (dict(str, array)): The dictionary of names to arrays to\n          be saved.\n        metadata (dict(str, str), optional): The dictionary of\n          metadata to be saved.\n    \"\"\"\n\ndef savez(file: str | pathlib.Path, *args, **kwargs):\n    \"\"\"\n    Save several arrays to a binary file in uncompressed ``.npz``\n    format.\n\n    .. code-block:: python\n\n        import mlx.core as mx\n\n        x = mx.ones((10, 10))\n        mx.savez(\"my_path.npz\", x=x)\n\n        import mlx.nn as nn\n        from mlx.utils import tree_flatten\n\n        model = nn.TransformerEncoder(6, 128, 4)\n        flat_params = tree_flatten(model.parameters())\n        mx.savez(\"model.npz\", **dict(flat_params))\n\n    Args:\n        file (file, str, pathlib.Path): Path to file to which the arrays are saved.\n        *args (arrays): Arrays to be saved.\n        **kwargs (arrays): Arrays to be saved. Each array will be saved\n          with the associated keyword as the output file name.\n    \"\"\"\n\ndef savez_compressed(file: str | pathlib.Path, *args, **kwargs):\n    \"\"\"\n    Save several arrays to a binary file in compressed ``.npz`` format.\n\n    Args:\n        file (file, str, pathlib.Path): Path to file to which the arrays are saved.\n        *args (arrays): Arrays to be saved.\n        **kwargs (arrays): Arrays to be saved. Each array will be saved\n          with the associated keyword as the output file name.\n    \"\"\"\n\ndef segmented_mm(\n    a: array, b: array, /, segments: array, *, stream: Stream | Device | None = ...\n) -> array:\n    \"\"\"\n    Perform a matrix multiplication but segment the inner dimension and\n    save the result for each segment separately.\n\n    Args:\n      a (array): Input array of shape ``MxK``.\n      b (array): Input array of shape ``KxN``.\n      segments (array): The offsets into the inner dimension for each segment.\n\n    Returns:\n      array: The result per segment of shape ``MxN``.\n    \"\"\"\n\ndef set_cache_limit(limit: int) -> int:\n    \"\"\"\n    Set the free cache limit.\n\n    If using more than the given limit, free memory will be reclaimed\n    from the cache on the next allocation. To disable the cache, set\n    the limit to ``0``.\n\n    The cache limit defaults to the memory limit. See\n    :func:`set_memory_limit` for more details.\n\n    Args:\n      limit (int): The cache limit in bytes.\n\n    Returns:\n      int: The previous cache limit in bytes.\n    \"\"\"\n\ndef set_default_device(device: Device | DeviceType) -> None:\n    \"\"\"Set the default device.\"\"\"\n\ndef set_default_stream(stream: Stream) -> None:\n    \"\"\"\n    Set the default stream.\n\n    This will make the given stream the default for the\n    stream's device. It will not change the default device.\n\n    Args:\n      stream (stream): Stream to make the default.\n    \"\"\"\n\ndef set_memory_limit(limit: int) -> int:\n    \"\"\"\n    Set the memory limit.\n\n    The memory limit is a guideline for the maximum amount of memory to use\n    during graph evaluation. If the memory limit is exceeded and there is no\n    more RAM (including swap when available), allocations will result in an\n    exception.\n\n    When Metal is available the memory limit defaults to 1.5 times the\n    maximum recommended working set size reported by the device.\n\n    Args:\n      limit (int): Memory limit in bytes.\n\n    Returns:\n      int: The previous memory limit in bytes.\n    \"\"\"\n\ndef set_wired_limit(limit: int) -> int:\n    \"\"\"\n    Set the wired size limit.\n\n    .. note::\n       * This function is only useful on macOS 15.0 or higher.\n       * The wired limit should remain strictly less than the total\n         memory size.\n\n    The wired limit is the total size in bytes of memory that will be kept\n    resident. The default value is ``0``.\n\n    Setting a wired limit larger than the system wired limit is an error. You can\n    increase the system wired limit with:\n\n    .. code-block::\n\n      sudo sysctl iogpu.wired_limit_mb=<size_in_megabytes>\n\n    Use :func:`device_info` to query the system wired limit\n    (``\"max_recommended_working_set_size\"``) and the total memory size\n    (``\"memory_size\"``).\n\n    Args:\n      limit (int): The wired limit in bytes.\n\n    Returns:\n      int: The previous wired limit in bytes.\n    \"\"\"\n\ndef sigmoid(a: array, /, *, stream: Stream | Device | None = ...) -> array:\n    r\"\"\"\n    Element-wise logistic sigmoid.\n\n    The logistic sigmoid function is:\n\n    .. math::\n      \\mathrm{sigmoid}(x) = \\frac{1}{1 + e^{-x}}\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: The logistic sigmoid of ``a``.\n    \"\"\"\n\ndef sign(a: array, /, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Element-wise sign.\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: The sign of ``a``.\n    \"\"\"\n\nsignedinteger: DtypeCategory = ...\n\ndef sin(a: array, /, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Element-wise sine.\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: The sine of ``a``.\n    \"\"\"\n\ndef sinh(a: array, /, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Element-wise hyperbolic sine.\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: The hyperbolic sine of ``a``.\n    \"\"\"\n\ndef slice(\n    a: array,\n    start_indices: array,\n    axes: Sequence[int],\n    slice_size: Sequence[int],\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Extract a sub-array from the input array.\n\n    Args:\n      a (array): Input array\n      start_indices (array): The index location to start the slice at.\n      axes (tuple(int)): The axes corresponding to the indices in ``start_indices``.\n      slice_size (tuple(int)): The size of the slice.\n\n    Returns:\n      array: The sliced output array.\n\n    Example:\n\n      >>> a = mx.array([[1, 2, 3], [4, 5, 6]])\n      >>> mx.slice(a, start_indices=mx.array(1), axes=(0,), slice_size=(1, 2))\n      array([[4, 5]], dtype=int32)\n      >>>\n      >>> mx.slice(a, start_indices=mx.array(1), axes=(1,), slice_size=(2, 1))\n      array([[2],\n             [5]], dtype=int32)\n    \"\"\"\n\ndef slice_update(\n    a: array,\n    update: array,\n    start_indices: array,\n    axes: Sequence[int],\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Update a sub-array of the input array.\n\n    Args:\n      a (array): The input array to update\n      update (array): The update array.\n      start_indices (array): The index location to start the slice at.\n      axes (tuple(int)): The axes corresponding to the indices in ``start_indices``.\n\n    Returns:\n      array: The output array with the same shape and type as the input.\n\n    Example:\n\n      >>> a = mx.zeros((3, 3))\n      >>> mx.slice_update(a, mx.ones((1, 2)), start_indices=mx.array([1, 1]), axes=(0, 1))\n      array([[0, 0, 0],\n             [0, 1, 1],\n             [0, 0, 0]], dtype=float32)\n    \"\"\"\n\ndef softmax(\n    a: array,\n    /,\n    axis: int | Sequence[int] | None = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Perform the softmax along the given axis.\n\n    This operation is a numerically stable version of:\n\n    .. code-block::\n\n      exp(a) / sum(exp(a), axis, keepdims=True)\n\n    Args:\n        a (array): Input array.\n        axis (int or list(int), optional): Optional axis or axes to compute\n         the softmax over. 
If unspecified this performs the softmax over\n         the full array.\n\n    Returns:\n        array: The output of the softmax.\n    \"\"\"\n\ndef sort(\n    a: array,\n    /,\n    axis: int | None = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Returns a sorted copy of the array.\n\n    Args:\n        a (array): Input array.\n        axis (int or None, optional): Optional axis to sort over.\n          If ``None``, this sorts over the flattened array.\n          If unspecified, it defaults to -1 (sorting over the last axis).\n\n    Returns:\n        array: The sorted array.\n    \"\"\"\n\ndef split(\n    a: array,\n    /,\n    indices_or_sections: int | Sequence[int],\n    axis: int = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Split an array along a given axis.\n\n    Args:\n        a (array): Input array.\n        indices_or_sections (int or list(int)): If ``indices_or_sections``\n          is an integer the array is split into that many sections of equal\n          size. An error is raised if this is not possible. If ``indices_or_sections``\n          is a list, the list contains the indices of the start of each subarray\n          along the given axis.\n        axis (int, optional): Axis to split along, defaults to `0`.\n\n    Returns:\n        list(array): A list of split arrays.\n    \"\"\"\n\ndef sqrt(a: array, /, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Element-wise square root.\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: The square root of ``a``.\n    \"\"\"\n\ndef square(a: array, /, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Element-wise square.\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: The square of ``a``.\n    \"\"\"\n\ndef squeeze(\n    a: array,\n    /,\n    axis: int | Sequence[int] | None = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Remove length one axes from an array.\n\n    Args:\n        a (array): Input array.\n        axis (int or tuple(int), optional): Axes to remove. Defaults\n          to ``None`` in which case all size one axes are removed.\n\n    Returns:\n        array: The output array with size one axes removed.\n    \"\"\"\n\ndef stack(\n    arrays: list[array],\n    axis: int | None = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Stacks the arrays along a new axis.\n\n    Args:\n        arrays (list(array)): A list of arrays to stack.\n        axis (int, optional): The axis in the result array along which the\n          input arrays are stacked. Defaults to ``0``.\n        stream (Stream, optional): Stream or device. Defaults to ``None``.\n\n    Returns:\n        array: The resulting stacked array.\n    \"\"\"\n\ndef std(\n    a: array,\n    /,\n    axis: int | Sequence[int] | None = ...,\n    keepdims: bool = ...,\n    ddof: int = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Compute the standard deviation(s) over the given axes.\n\n    Args:\n        a (array): Input array.\n        axis (int or list(int), optional): Optional axis or\n          axes to reduce over. 
If unspecified this defaults\n          to reducing over the entire array.\n        keepdims (bool, optional): Keep reduced axes as\n          singleton dimensions, defaults to `False`.\n        ddof (int, optional): The divisor to compute the variance\n          is ``N - ddof``, defaults to 0.\n\n    Returns:\n        array: The output array of standard deviations.\n    \"\"\"\n\ndef stop_gradient(a: array, /, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Stop gradients from being computed.\n\n    The operation is the identity but it prevents gradients from flowing\n    through the array.\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array:\n          The unchanged input ``a`` but without gradient flowing\n          through it.\n    \"\"\"\n\ndef stream(s: Stream | Device) -> StreamContext:\n    \"\"\"\n    Create a context manager to set the default device and stream.\n\n    Args:\n        s: The :obj:`Stream` or :obj:`Device` to set as the default.\n\n    Returns:\n        A context manager that sets the default device and stream.\n\n    Example:\n\n    .. code-block:: python\n\n      import mlx.core as mx\n\n      # Create a context manager for the default device and stream.\n      with mx.stream(mx.cpu):\n          # Operations here will use mx.cpu by default.\n          pass\n    \"\"\"\n\ndef subtract(\n    a: scalar | array,\n    b: scalar | array,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Element-wise subtraction.\n\n    Subtract one array from another with numpy-style broadcasting semantics. Either or both\n    input arrays can also be scalars.\n\n    Args:\n        a (array): Input array or scalar.\n        b (array): Input array or scalar.\n\n    Returns:\n        array: The difference ``a - b``.\n    \"\"\"\n\ndef sum(\n    a: array,\n    /,\n    axis: int | Sequence[int] | None = ...,\n    keepdims: bool = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Sum reduce the array over the given axes.\n\n    Args:\n        a (array): Input array.\n        axis (int or list(int), optional): Optional axis or\n          axes to reduce over. If unspecified this defaults\n          to reducing over the entire array.\n        keepdims (bool, optional): Keep reduced axes as\n          singleton dimensions, defaults to `False`.\n\n    Returns:\n        array: The output array with the corresponding axes reduced.\n    \"\"\"\n\ndef swapaxes(\n    a: array, /, axis1: int, axis2: int, *, stream: Stream | Device | None = ...\n) -> array:\n    \"\"\"\n    Swap two axes of an array.\n\n    Args:\n        a (array): Input array.\n        axis1 (int): Specifies the first axis.\n        axis2 (int): Specifies the second axis.\n\n    Returns:\n        array: The array with swapped axes.\n    \"\"\"\n\ndef synchronize(stream: Stream | None = ...) -> None:\n    \"\"\"\n    Synchronize with the given stream.\n\n    Args:\n      stream (Stream, optional): The stream to synchronize with. If ``None``\n         then the default stream of the default device is used.\n         Default: ``None``.\n    \"\"\"\n\ndef take(\n    a: array,\n    /,\n    indices: int | array,\n    axis: int | None = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Take elements along an axis.\n\n    The elements are taken from ``indices`` along the specified axis.\n    If the axis is not specified the array is treated as a flattened\n    1-D array prior to performing the take.\n\n    As an example, if ``axis=1`` this is equivalent to ``a[:, indices, ...]``.\n\n    Args:\n        a (array): Input array.\n        indices (int or array): Integer index or input array with integral type.\n        axis (int, optional): Axis along which to perform the take. If unspecified\n          the array is treated as a flattened 1-D vector.\n\n    Returns:\n        array: The indexed values of ``a``.\n    \"\"\"\n\ndef take_along_axis(\n    a: array,\n    /,\n    indices: array,\n    axis: int | None = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Take values along an axis at the specified indices.\n\n    Args:\n        a (array): Input array.\n        indices (array): Indices array. These should be broadcastable with\n          the input array excluding the `axis` dimension.\n        axis (int or None): Axis in the input to take the values from. If\n          ``axis == None`` the array is flattened to 1D prior to the indexing\n          operation.\n\n    Returns:\n        array: The output array.\n    \"\"\"\n\ndef tan(a: array, /, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Element-wise tangent.\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: The tangent of ``a``.\n    \"\"\"\n\ndef tanh(a: array, /, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Element-wise hyperbolic tangent.\n\n    Args:\n        a (array): Input array.\n\n    Returns:\n        array: The hyperbolic tangent of ``a``.\n    \"\"\"\n\ndef tensordot(\n    a: array,\n    b: array,\n    /,\n    axes: int | list[Sequence[int]] = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Compute the tensor dot product along the specified axes.\n\n    Args:\n        a (array): Input array\n        b (array): Input array\n        axes (int or list(list(int)), optional): The number of dimensions to\n          sum over. If an integer is provided, then sum over the last\n          ``axes`` dimensions of ``a`` and the first ``axes`` dimensions of\n          ``b``. If a list of lists is provided, then sum over the\n          corresponding dimensions of ``a`` and ``b``. Default: 2.\n\n    Returns:\n        array: The tensor dot product.\n    \"\"\"\n\ndef tile(\n    a: array,\n    reps: int | Sequence[int],\n    /,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Construct an array by repeating ``a`` the number of times given by ``reps``.\n\n    Args:\n      a (array): Input array\n      reps (int or list(int)): The number of times to repeat ``a`` along each axis.\n\n    Returns:\n      array: The tiled array.\n    \"\"\"\n\ndef topk(\n    a: array,\n    /,\n    k: int,\n    axis: int | None = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Returns the ``k`` largest elements from the input along a given axis.\n\n    The elements will not necessarily be in sorted order.\n\n    Args:\n        a (array): Input array.\n        k (int): ``k`` top elements to be returned.\n        axis (int or None, optional): Optional axis to select over.\n          If ``None``, this selects the top ``k`` elements over the\n          flattened array. If unspecified, it defaults to ``-1``.\n\n    Returns:\n        array: The top ``k`` elements from the input.\n    \"\"\"\n\ndef trace(\n    a: array,\n    /,\n    offset: int = ...,\n    axis1: int = ...,\n    axis2: int = ...,\n    dtype: Dtype | None = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Return the sum along a specified diagonal in the given array.\n\n    Args:\n      a (array): Input array\n      offset (int, optional): Offset of the diagonal from the main diagonal.\n        Can be positive or negative. Default: ``0``.\n      axis1 (int, optional): The first axis of the 2-D sub-arrays from which\n          the diagonals should be taken. Default: ``0``.\n      axis2 (int, optional): The second axis of the 2-D sub-arrays from which\n          the diagonals should be taken. Default: ``1``.\n      dtype (Dtype, optional): Data type of the output array. If\n          unspecified the output type is inferred from the input array.\n\n    Returns:\n        array: Sum of specified diagonal.\n    \"\"\"\n\ndef transpose(\n    a: array,\n    /,\n    axes: Sequence[int] | None = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Transpose the dimensions of the array.\n\n    Args:\n        a (array): Input array.\n        axes (list(int), optional): Specifies the source axis for each axis\n          in the new array. The default is to reverse the axes.\n\n    Returns:\n        array: The transposed array.\n    \"\"\"\n\ndef tri(\n    n: int,\n    m: int | None = ...,\n    k: int = ...,\n    dtype: Dtype | None = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    An array with ones at and below the given diagonal and zeros elsewhere.\n\n    Args:\n      n (int): The number of rows in the output.\n      m (int, optional): The number of cols in the output. Defaults to ``None``.\n      k (int, optional): The diagonal of the 2-D array. Defaults to ``0``.\n      dtype (Dtype, optional): Data type of the output array. Defaults to ``float32``.\n      stream (Stream, optional): Stream or device. Defaults to ``None``.\n\n    Returns:\n      array: Array with its lower triangle filled with ones and zeros elsewhere\n    \"\"\"\n\ndef tril(x: array, k: int = ..., *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Zeros the array above the given diagonal.\n\n    Args:\n      x (array): input array.\n      k (int, optional): The diagonal of the 2-D array. Defaults to ``0``.\n      stream (Stream, optional): Stream or device. Defaults to ``None``.\n\n    Returns:\n      array: Array zeroed above the given diagonal\n    \"\"\"\n\ndef triu(x: array, k: int = ..., *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Zeros the array below the given diagonal.\n\n    Args:\n      x (array): input array.\n      k (int, optional): The diagonal of the 2-D array. Defaults to ``0``.\n      stream (Stream, optional): Stream or device. Defaults to ``None``.\n\n    Returns:\n      array: Array zeroed below the given diagonal\n    \"\"\"\n\nuint16: Dtype = ...\nuint32: Dtype = ...\nuint64: Dtype = ...\nuint8: Dtype = ...\n\ndef unflatten(\n    a: array,\n    /,\n    axis: int,\n    shape: Sequence[int],\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Unflatten an axis of an array to a shape.\n\n    Args:\n        a (array): Input array.\n        axis (int): The axis to unflatten.\n        shape (tuple(int)): The shape to unflatten to. At most one\n          entry can be ``-1`` in which case the corresponding size will be\n          inferred.\n        stream (Stream, optional): Stream or device. Defaults to ``None``\n          in which case the default stream of the default device is used.\n\n    Returns:\n        array: The unflattened array.\n\n    Example:\n        >>> a = mx.array([1, 2, 3, 4])\n        >>> mx.unflatten(a, 0, (2, -1))\n        array([[1, 2], [3, 4]], dtype=int32)\n    \"\"\"\n\nunsignedinteger: DtypeCategory = ...\n\ndef value_and_grad(\n    fun: Callable,\n    argnums: int | Sequence[int] | None = ...,\n    argnames: str | Sequence[str] = ...,\n) -> Callable:\n    \"\"\"\n    Returns a function which computes the value and gradient of ``fun``.\n\n    The function passed to :func:`value_and_grad` should return either\n    a scalar loss or a tuple in which the first element is a scalar\n    loss and the remaining elements can be anything.\n\n    .. code-block:: python\n\n        import mlx.core as mx\n\n        def mse(params, inputs, targets):\n            outputs = forward(params, inputs)\n            lvalue = (outputs - targets).square().mean()\n            return lvalue\n\n        # Returns lvalue, dlvalue/dparams\n        lvalue, grads = mx.value_and_grad(mse)(params, inputs, targets)\n\n        def lasso(params, inputs, targets, a=1.0, b=1.0):\n            outputs = forward(params, inputs)\n            mse = (outputs - targets).square().mean()\n            l1 = mx.abs(outputs - targets).mean()\n\n            loss = a*mse + b*l1\n\n            return loss, mse, l1\n\n        (loss, mse, l1), grads = mx.value_and_grad(lasso)(params, inputs, targets)\n\n    Args:\n        fun (Callable): A function which takes a variable number of\n          :class:`array` or trees of :class:`array` and returns\n          a scalar output :class:`array` or a tuple the first element\n          of which should be a scalar :class:`array`.\n        argnums (int or list(int), optional): Specify the index (or indices)\n          of the positional arguments of ``fun`` to compute the gradient\n          with respect to. If neither ``argnums`` nor ``argnames`` are\n          provided ``argnums`` defaults to ``0`` indicating ``fun``'s first\n          argument.\n        argnames (str or list(str), optional): Specify keyword arguments of\n          ``fun`` to compute gradients with respect to. It defaults to ``[]`` so\n          no gradients are computed for keyword arguments by default.\n\n    Returns:\n        Callable: A function which returns a tuple where the first element\n        is the output of `fun` and the second element is the gradients with\n        respect to the specified arguments.\n    \"\"\"\n\ndef var(\n    a: array,\n    /,\n    axis: int | Sequence[int] | None = ...,\n    keepdims: bool = ...,\n    ddof: int = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Compute the variance(s) over the given axes.\n\n    Args:\n        a (array): Input array.\n        axis (int or list(int), optional): Optional axis or\n          axes to reduce over. If unspecified this defaults\n          to reducing over the entire array.\n        keepdims (bool, optional): Keep reduced axes as\n          singleton dimensions, defaults to `False`.\n        ddof (int, optional): The divisor to compute the variance\n          is ``N - ddof``, defaults to 0.\n\n    Returns:\n        array: The output array of variances.\n    \"\"\"\n\ndef view(\n    a: scalar | array, dtype: Dtype, stream: Stream | Device | None = ...\n) -> array:\n    \"\"\"\n    View the array as a different type.\n\n    The output shape changes along the last axis if the input array's\n    type and the input ``dtype`` do not have the same size.\n\n    Note: the view op does not imply that the input and output arrays share\n    their underlying data. The view only guarantees that the binary\n    representation of each element (or group of elements) is the same.\n\n    Args:\n        a (array): Input array or scalar.\n        dtype (Dtype): The data type to change to.\n\n    Returns:\n        array: The array with the new type.\n    \"\"\"\n\ndef vjp(\n    fun: Callable, primals: list[array], cotangents: list[array]\n) -> tuple[list[array], list[array]]:\n    \"\"\"\n    Compute the vector-Jacobian product.\n\n    Computes the product of the ``cotangents`` with the Jacobian of a\n    function ``fun`` evaluated at ``primals``.\n\n    Args:\n      fun (Callable): A function which takes a variable number of :class:`array`\n        and returns a single :class:`array` or list of :class:`array`.\n      primals (list(array)): A list of :class:`array` at which to\n        evaluate the Jacobian.\n      cotangents (list(array)): A list of :class:`array` which are the\n        \"vector\" in the vector-Jacobian product. The ``cotangents`` should be the\n        same in number, shape, and type as the outputs of ``fun``.\n\n    Returns:\n        list(array): A list of the vector-Jacobian products which\n        is the same in number, shape, and type as the outputs of ``fun``.\n    \"\"\"\n\ndef vmap(fun: Callable, in_axes: object = ..., out_axes: object = ...) -> Callable:\n    \"\"\"\n    Returns a vectorized version of ``fun``.\n\n    Args:\n        fun (Callable): A function which takes a variable number of\n          :class:`array` or a tree of :class:`array` and returns\n          a variable number of :class:`array` or a tree of :class:`array`.\n        in_axes (int, optional): An integer or a valid prefix tree of the\n          inputs to ``fun`` where each node specifies the vmapped axis. If\n          the value is ``None`` then the corresponding input(s) are not vmapped.\n          Defaults to ``0``.\n        out_axes (int, optional): An integer or a valid prefix tree of the\n          outputs of ``fun`` where each node specifies the vmapped axis. If\n          the value is ``None`` then the corresponding output(s) are not vmapped.\n          Defaults to ``0``.\n\n    Returns:\n        Callable: The vectorized function.\n    \"\"\"\n\ndef where(\n    condition: scalar | array,\n    x: scalar | array,\n    y: scalar | array,\n    /,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Select from ``x`` or ``y`` according to ``condition``.\n\n    The condition and input arrays must be the same shape or\n    broadcastable with one another.\n\n    Args:\n      condition (array): The condition array.\n      x (array): The input selected where ``condition`` is ``True``.\n      y (array): The input selected where ``condition`` is ``False``.\n\n    Returns:\n        array: The output containing elements selected from\n        ``x`` and ``y``.\n    \"\"\"\n\ndef zeros(\n    shape: int | Sequence[int],\n    dtype: Dtype | None = ...,\n    *,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Construct an array of zeros.\n\n    Args:\n        shape (int or list(int)): The shape of the output array.\n        dtype (Dtype, optional): Data type of the output array. If\n          unspecified the output type defaults to ``float32``.\n\n    Returns:\n        array: The array of zeros with the specified shape.\n    \"\"\"\n\ndef zeros_like(a: array, /, *, stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    An array of zeros like the input.\n\n    Args:\n        a (array): The input to take the shape and type from.\n\n    Returns:\n        array: The output array filled with zeros.\n    \"\"\"\n\nscalar: TypeAlias = int | float | bool\nlist_or_scalar: TypeAlias = scalar | list[\"list_or_scalar\"]\nbool_: Dtype = ...\n"
  },
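  {
    "path": "examples/quantized_matmul_sketch.py",
    "content": "# Hypothetical usage sketch (not shipped with MLX): it exercises the\n# ``quantize`` and ``quantized_matmul`` signatures stubbed in\n# mlx/core/__init__.pyi above. The file path, shapes, and the final\n# comparison are illustrative assumptions, not a definitive implementation.\nimport mlx.core as mx\n\nx = mx.random.normal((4, 64))\n# 128 rows of 64 columns; the column count must be divisible by group_size.\nw = mx.random.normal((128, 64))\n\n# Affine mode returns the packed matrix plus per-group scales and biases.\nw_q, scales, biases = mx.quantize(w, group_size=64, bits=4, mode=\"affine\")\n\n# transpose=True multiplies against the transposed packed matrix, matching\n# the exact computation x @ w.T below.\ny_quantized = mx.quantized_matmul(\n    x, w_q, scales, biases, transpose=True, group_size=64, bits=4, mode=\"affine\"\n)\ny_exact = x @ w.T\n\n# Quantization is lossy, so only rough agreement is expected here.\nprint(mx.abs(y_exact - y_quantized).max())\n"
  },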
  {
    "path": ".mlx_typings/mlx/core/cuda/__init__.pyi",
    "content": "def is_available() -> bool:\n    \"\"\"Check if the CUDA back-end is available.\"\"\"\n"
  },
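  {
    "path": "examples/device_selection_sketch.py",
    "content": "# Hypothetical usage sketch (not shipped with MLX): combines the back-end\n# availability checks stubbed in mlx/core/cuda and mlx/core/metal with\n# set_default_device and the stream context manager from mlx/core.\nimport mlx.core as mx\n\n# Prefer the GPU when either back-end reports availability, otherwise fall\n# back to the CPU; mx.gpu and mx.cpu are the stock MLX device handles.\nif mx.metal.is_available() or mx.cuda.is_available():\n    mx.set_default_device(mx.gpu)\nelse:\n    mx.set_default_device(mx.cpu)\n\n# Operations inside the context run on the CPU regardless of the default.\nwith mx.stream(mx.cpu):\n    print(mx.ones((2, 2)).sum())\n"
  },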
  {
    "path": ".mlx_typings/mlx/core/distributed/__init__.pyi",
    "content": "from typing import Sequence\n\nfrom mlx.core import Device, Dtype, Stream, array\n\nclass Group:\n    \"\"\"\n    An :class:`mlx.core.distributed.Group` represents a group of independent mlx\n    processes that can communicate.\n    \"\"\"\n    def rank(self) -> int:\n        \"\"\"Get the rank of this process\"\"\"\n\n    def size(self) -> int:\n        \"\"\"Get the size of the group\"\"\"\n\n    def split(self, color: int, key: int = ...) -> Group:\n        \"\"\"\n        Split the group to subgroups based on the provided color.\n\n        Processes that use the same color go to the same group. The ``key``\n        argument defines the rank in the new group. The smaller the key the\n        smaller the rank. If the key is negative then the rank in the\n        current group is used.\n\n        Args:\n          color (int): A value to group processes into subgroups.\n          key (int, optional): A key to optionally change the rank ordering\n            of the processes.\n        \"\"\"\n\ndef all_gather(\n    x: array, *, group: Group | None = ..., stream: Stream | Device | None = ...\n) -> array:\n    \"\"\"\n    Gather arrays from all processes.\n\n    Gather the ``x`` arrays from all processes in the group and concatenate\n    them along the first axis. The arrays should all have the same shape.\n\n    Args:\n      x (array): Input array.\n      group (Group): The group of processes that will participate in the\n        gather. If set to ``None`` the global group is used. Default:\n        ``None``.\n      stream (Stream, optional): Stream or device. Defaults to ``None``\n        in which case the default stream of the default device is used.\n\n    Returns:\n      array: The concatenation of all ``x`` arrays.\n    \"\"\"\n\ndef all_max(\n    x: array, *, group: Group | None = ..., stream: Stream | Device | None = ...\n) -> array:\n    \"\"\"\n    All reduce max.\n\n    Find the maximum of the ``x`` arrays from all processes in the group.\n\n    Args:\n      x (array): Input array.\n      group (Group): The group of processes that will participate in the\n        reduction. If set to ``None`` the global group is used. Default:\n        ``None``.\n      stream (Stream, optional): Stream or device. Defaults to ``None``\n        in which case the default stream of the default device is used.\n\n    Returns:\n      array: The maximum of all ``x`` arrays.\n    \"\"\"\n\ndef all_min(\n    x: array, *, group: Group | None = ..., stream: Stream | Device | None = ...\n) -> array:\n    \"\"\"\n    All reduce min.\n\n    Find the minimum of the ``x`` arrays from all processes in the group.\n\n    Args:\n      x (array): Input array.\n      group (Group): The group of processes that will participate in the\n        reduction. If set to ``None`` the global group is used. Default:\n        ``None``.\n      stream (Stream, optional): Stream or device. Defaults to ``None``\n        in which case the default stream of the default device is used.\n\n    Returns:\n      array: The minimum of all ``x`` arrays.\n    \"\"\"\n\ndef all_sum(\n    x: array, *, group: Group | None = ..., stream: Stream | Device | None = ...\n) -> array:\n    \"\"\"\n    All reduce sum.\n\n    Sum the ``x`` arrays from all processes in the group.\n\n    Args:\n      x (array): Input array.\n      group (Group): The group of processes that will participate in the\n        reduction. If set to ``None`` the global group is used. Default:\n        ``None``.\n      stream (Stream, optional): Stream or device. 
Defaults to ``None``\n        in which case the default stream of the default device is used.\n\n    Returns:\n      array: The sum of all ``x`` arrays.\n    \"\"\"\n\ndef init(strict: bool = ..., backend: str = ...) -> Group:\n    \"\"\"\n    Initialize the communication backend and create the global communication group.\n\n    Example:\n\n      .. code:: python\n\n        import mlx.core as mx\n\n        group = mx.distributed.init(backend=\"ring\")\n\n    Args:\n      strict (bool, optional): If set to ``False`` a singleton group is\n        returned when ``mx.distributed.is_available()`` returns ``False``;\n        otherwise a runtime error is thrown. Default: ``False``.\n      backend (str, optional): Which distributed backend to initialize.\n        Possible values are ``mpi``, ``ring``, ``nccl``, and ``any``. If set to\n        ``any`` all available backends are tried and the first one that\n        succeeds becomes the global group which will be returned in\n        subsequent calls. Default: ``any``.\n\n    Returns:\n      Group: The group representing all the launched processes.\n    \"\"\"\n\ndef is_available() -> bool:\n    \"\"\"Check if a communication backend is available.\"\"\"\n\ndef recv(\n    shape: Sequence[int],\n    dtype: Dtype,\n    src: int,\n    *,\n    group: Group | None = ...,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Recv an array with shape ``shape`` and dtype ``dtype`` from process\n    with rank ``src``.\n\n    Args:\n      shape (tuple[int]): The shape of the array we are receiving.\n      dtype (Dtype): The data type of the array we are receiving.\n      src (int): Rank of the source process in the group.\n      group (Group): The group of processes that will participate in the\n        recv. If set to ``None`` the global group is used. Default:\n        ``None``.\n      stream (Stream, optional): Stream or device. Defaults to ``None``\n        in which case the default stream of the default device is used.\n\n    Returns:\n      array: The array that was received from ``src``.\n    \"\"\"\n\ndef recv_like(\n    x: array,\n    src: int,\n    *,\n    group: Group | None = ...,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Recv an array with shape and type like ``x`` from process with rank\n    ``src``.\n\n    It is equivalent to calling ``mx.distributed.recv(x.shape, x.dtype, src)``.\n\n    Args:\n      x (array): An array defining the shape and dtype of the array we are\n        receiving.\n      src (int): Rank of the source process in the group.\n      group (Group): The group of processes that will participate in the\n        recv. If set to ``None`` the global group is used. Default:\n        ``None``.\n      stream (Stream, optional): Stream or device. Defaults to ``None``\n        in which case the default stream of the default device is used.\n\n    Returns:\n      array: The array that was received from ``src``.\n    \"\"\"\n\ndef send(\n    x: array,\n    dst: int,\n    *,\n    group: Group | None = ...,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Send an array from the current process to the process that has rank\n    ``dst`` in the group.\n\n    Args:\n      x (array): Input array.\n      dst (int): Rank of the destination process in the group.\n      group (Group): The group of processes that will participate in the\n        send. If set to ``None`` the global group is used. Default:\n        ``None``.\n      stream (Stream, optional): Stream or device. Defaults to ``None``\n        in which case the default stream of the default device is used.\n\n    Returns:\n      array: An array identical to ``x``, which, when evaluated, causes the\n        send to be performed.\n    \"\"\"\n"
  },
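  {
    "path": "examples/distributed_usage_sketch.py",
    "content": "\"\"\"Hypothetical usage sketch, not part of MLX or these stubs: shows how the\n``mlx.core.distributed`` API documented above fits together. The file path and\nscript are illustrative assumptions only.\"\"\"\n\nimport mlx.core as mx\n\n# With strict=False (the default) this returns a singleton group when no\n# communication backend is available, so the sketch also runs on one process.\ngroup = mx.distributed.init(strict=False)\nprint(f'rank {group.rank()} of {group.size()}')\n\n# Every process contributes its own array; all_sum returns the element-wise\n# sum across the whole group on every rank.\nlocal_values = mx.ones((4,)) * group.rank()\ntotal = mx.distributed.all_sum(local_values, group=group)\nmx.eval(total)\nprint(total)\n"
  },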
  {
    "path": ".mlx_typings/mlx/core/metal/__init__.pyi",
    "content": "def clear_cache() -> None: ...\ndef device_info() -> dict[str, str | int]:\n    \"\"\"\n    Get information about the GPU device and system settings.\n\n    Currently returns:\n\n    * ``architecture``\n    * ``max_buffer_size``\n    * ``max_recommended_working_set_size``\n    * ``memory_size``\n    * ``resource_limit``\n\n    Returns:\n        dict: A dictionary with string keys and string or integer values.\n    \"\"\"\n\ndef get_active_memory() -> int: ...\ndef get_cache_memory() -> int: ...\ndef get_peak_memory() -> int: ...\ndef is_available() -> bool:\n    \"\"\"Check if the Metal back-end is available.\"\"\"\n\ndef reset_peak_memory() -> None: ...\ndef set_cache_limit(limit: int) -> int: ...\ndef set_memory_limit(limit: int) -> int: ...\ndef set_wired_limit(limit: int) -> int: ...\ndef start_capture(path: str) -> None:\n    \"\"\"\n    Start a Metal capture.\n\n    Args:\n      path (str): The path to save the capture which should have\n        the extension ``.gputrace``.\n    \"\"\"\n\ndef stop_capture() -> None:\n    \"\"\"Stop a Metal capture.\"\"\"\n"
  },
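  {
    "path": "examples/metal_memory_sketch.py",
    "content": "\"\"\"Hypothetical usage sketch, not part of MLX or these stubs: queries the\nMetal back-end documented above and tracks peak memory around a workload.\nThe file path is an illustrative assumption only.\"\"\"\n\nimport mlx.core as mx\n\nif mx.metal.is_available():\n    info = mx.metal.device_info()\n    print(info['architecture'], info['memory_size'])\n\n    # Peak-memory tracking brackets a workload: reset, run, then read back.\n    mx.metal.reset_peak_memory()\n    a = mx.random.normal((1024, 1024))\n    mx.eval(a @ a)\n    print(mx.metal.get_peak_memory(), 'bytes at peak')\n"
  },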
  {
    "path": ".mlx_typings/mlx/core/random/__init__.pyi",
    "content": "from typing import Sequence\n\nfrom mlx.core import Device, Dtype, Stream, array, scalar\nfrom mlx.core.distributed import state as state\n\ndef bernoulli(\n    p: scalar | array = ...,\n    shape: Sequence[int] | None = ...,\n    key: array | None = ...,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Generate Bernoulli random values.\n\n    The values are sampled from the bernoulli distribution with parameter\n    ``p``. The parameter ``p`` can be a :obj:`float` or :obj:`array` and\n    must be broadcastable to ``shape``.\n\n    Args:\n        p (float or array, optional): Parameter of the Bernoulli\n          distribution. Default: ``0.5``.\n        shape (list(int), optional): Shape of the output.\n          Default: ``p.shape``.\n        key (array, optional): A PRNG key. Default: ``None``.\n\n    Returns:\n        array: The array of random integers.\n    \"\"\"\n\ndef categorical(\n    logits: array,\n    axis: int = ...,\n    shape: Sequence[int] | None = ...,\n    num_samples: int | None = ...,\n    key: array | None = ...,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Sample from a categorical distribution.\n\n    The values are sampled from the categorical distribution specified by\n    the unnormalized values in ``logits``. Note, at most one of ``shape``\n    or ``num_samples`` can be specified. If both are ``None``, the output\n    has the same shape as ``logits`` with the ``axis`` dimension removed.\n\n    Args:\n        logits (array): The *unnormalized* categorical distribution(s).\n        axis (int, optional): The axis which specifies the distribution.\n           Default: ``-1``.\n        shape (list(int), optional): The shape of the output. This must\n           be broadcast compatible with ``logits.shape`` with the ``axis``\n           dimension removed. Default: ``None``\n        num_samples (int, optional): The number of samples to draw from each\n          of the categorical distributions in ``logits``. The output will have\n          ``num_samples`` in the last dimension. Default: ``None``.\n        key (array, optional): A PRNG key. Default: ``None``.\n\n    Returns:\n        array: The ``shape``-sized output array with type ``uint32``.\n    \"\"\"\n\ndef gumbel(\n    shape: Sequence[int] = ...,\n    dtype: Dtype | None = ...,\n    key: Stream | Device | None = ...,\n    stream: array | None = ...,\n) -> array:\n    \"\"\"\n    Sample from the standard Gumbel distribution.\n\n    The values are sampled from a standard Gumbel distribution\n    which CDF ``exp(-exp(-x))``.\n\n    Args:\n        shape (list(int)): The shape of the output.\n        dtype (Dtype, optional): The data type of the output.\n          Default: ``float32``.\n        key (array, optional): A PRNG key. Default: ``None``.\n\n    Returns:\n        array:\n          The :class:`array` with shape ``shape`` and distributed according\n          to the Gumbel distribution.\n    \"\"\"\n\ndef key(seed: int) -> array:\n    \"\"\"\n    Get a PRNG key from a seed.\n\n    Args:\n        seed (int): Seed for the PRNG.\n\n    Returns:\n        array: The PRNG key array.\n    \"\"\"\n\ndef laplace(\n    shape: Sequence[int] = ...,\n    dtype: Dtype | None = ...,\n    loc: float = ...,\n    scale: float = ...,\n    key: array | None = ...,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Sample numbers from a Laplace distribution.\n\n    Args:\n        shape (list(int), optional): Shape of the output. 
Default: ``()``.\n        dtype (Dtype, optional): Type of the output. Default: ``float32``.\n        loc (float, optional): Mean of the distribution. Default: ``0.0``.\n        scale (float, optional): The scale \"b\" of the Laplace distribution.\n          Default:``1.0``.\n        key (array, optional): A PRNG key. Default: ``None``.\n\n    Returns:\n        array: The output array of random values.\n    \"\"\"\n\ndef multivariate_normal(\n    mean: array,\n    cov: array,\n    shape: Sequence[int] = ...,\n    dtype: Dtype | None = ...,\n    key: array | None = ...,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Generate jointly-normal random samples given a mean and covariance.\n\n    The matrix ``cov`` must be positive semi-definite. The behavior is\n    undefined if it is not.  The only supported ``dtype`` is ``float32``.\n\n    Args:\n        mean (array): array of shape ``(..., n)``, the mean of the\n          distribution.\n        cov (array): array  of shape ``(..., n, n)``, the covariance\n          matrix of the distribution. The batch shape ``...`` must be\n          broadcast-compatible with that of ``mean``.\n        shape (list(int), optional): The output shape must be\n          broadcast-compatible with ``mean.shape[:-1]`` and ``cov.shape[:-2]``.\n          If empty, the result shape is determined by broadcasting the batch\n          shapes of ``mean`` and ``cov``. Default: ``[]``.\n        dtype (Dtype, optional): The output type. Default: ``float32``.\n        key (array, optional): A PRNG key. Default: ``None``.\n\n    Returns:\n        array: The output array of random values.\n    \"\"\"\n\ndef normal(\n    shape: Sequence[int] = ...,\n    dtype: Dtype | None = ...,\n    loc: scalar | array | None = ...,\n    scale: scalar | array | None = ...,\n    key: array | None = ...,\n    stream: Stream | Device | None = ...,\n) -> array:\n    r\"\"\"\n    Generate normally distributed random numbers.\n\n    If ``loc`` and ``scale`` are not provided the \"standard\" normal\n    distribution is used. That means $x \\sim \\mathcal{N}(0, 1)$ for\n    real numbers and $\\text{Re}(x),\\text{Im}(x) \\sim \\mathcal{N}(0,\n    \\frac{1}{2})$ for complex numbers.\n\n    Args:\n        shape (list(int), optional): Shape of the output. Default: ``()``.\n        dtype (Dtype, optional): Type of the output. Default: ``float32``.\n        loc (scalar or array, optional): Mean of the distribution.\n          Default: ``None``.\n        scale (scalar or array, optional): Standard deviation of the\n          distribution. Default: ``None``.\n        key (array, optional): A PRNG key. Default: ``None``.\n\n    Returns:\n        array: The output array of random values.\n    \"\"\"\n\ndef permutation(\n    x: int | array,\n    axis: int = ...,\n    key: array | None = ...,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Generate a random permutation or permute the entries of an array.\n\n    Args:\n        x (int or array, optional): If an integer is provided a random\n          permtuation of ``mx.arange(x)`` is returned. Otherwise the entries\n          of ``x`` along the given axis are randomly permuted.\n        axis (int, optional): The axis to permute along. Default: ``0``.\n        key (array, optional): A PRNG key. 
Default: ``None``.\n\n    Returns:\n        array:\n          The generated random permutation or randomly permuted input array.\n    \"\"\"\n\ndef randint(\n    low: scalar | array,\n    high: scalar | array,\n    shape: Sequence[int] = ...,\n    dtype: Dtype | None = ...,\n    key: array | None = ...,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Generate random integers from the given interval.\n\n    The values are sampled with equal probability from the integers in\n    half-open interval ``[low, high)``. The lower and upper bound can be\n    scalars or arrays and must be broadcastable to ``shape``.\n\n    Args:\n        low (scalar or array): Lower bound of the interval.\n        high (scalar or array): Upper bound of the interval.\n        shape (list(int), optional): Shape of the output. Default: ``()``.\n        dtype (Dtype, optional): Type of the output. Default: ``int32``.\n        key (array, optional): A PRNG key. Default: ``None``.\n\n    Returns:\n        array: The array of random integers.\n    \"\"\"\n\ndef seed(seed: int) -> None:\n    \"\"\"\n    Seed the global PRNG.\n\n    Args:\n        seed (int): Seed for the global PRNG.\n    \"\"\"\n\ndef split(key: array, num: int = ..., stream: Stream | Device | None = ...) -> array:\n    \"\"\"\n    Split a PRNG key into sub keys.\n\n    Args:\n        key (array): Input key to split.\n        num (int, optional): Number of sub keys. Default: ``2``.\n\n    Returns:\n        array: The array of sub keys with ``num`` as its first dimension.\n    \"\"\"\n\ndef truncated_normal(\n    lower: scalar | array,\n    upper: scalar | array,\n    shape: Sequence[int] | None = ...,\n    dtype: Dtype | None = ...,\n    key: array | None = ...,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Generate values from a truncated normal distribution.\n\n    The values are sampled from the truncated normal distribution\n    on the domain ``(lower, upper)``. The bounds ``lower`` and ``upper``\n    can be scalars or arrays and must be broadcastable to ``shape``.\n\n    Args:\n        lower (scalar or array): Lower bound of the domain.\n        upper (scalar or array): Upper bound of the domain.\n        shape (list(int), optional): The shape of the output.\n          Default:``()``.\n        dtype (Dtype, optional): The data type of the output.\n          Default: ``float32``.\n        key (array, optional): A PRNG key. Default: ``None``.\n\n    Returns:\n        array: The output array of random values.\n    \"\"\"\n\ndef uniform(\n    low: scalar | array = ...,\n    high: scalar | array = ...,\n    shape: Sequence[int] = ...,\n    dtype: Dtype | None = ...,\n    key: array | None = ...,\n    stream: Stream | Device | None = ...,\n) -> array:\n    \"\"\"\n    Generate uniformly distributed random numbers.\n\n    The values are sampled uniformly in the half-open interval ``[low, high)``.\n    The lower and upper bound can be scalars or arrays and must be\n    broadcastable to ``shape``.\n\n    Args:\n        low (scalar or array, optional): Lower bound of the distribution.\n          Default: ``0``.\n        high (scalar or array, optional): Upper bound of the distribution.\n          Default: ``1``.\n        shape (list(int), optional): Shape of the output. Default:``()``.\n        dtype (Dtype, optional): Type of the output. Default: ``float32``.\n        key (array, optional): A PRNG key. Default: ``None``.\n\n    Returns:\n        array: The output array random values.\n    \"\"\"\n"
  },
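  {
    "path": "examples/random_keys_sketch.py",
    "content": "\"\"\"Hypothetical usage sketch, not part of MLX or these stubs: explicit PRNG\nkeys with ``mlx.core.random`` as documented above. The file path is an\nillustrative assumption only.\"\"\"\n\nimport mlx.core as mx\n\n# Derive independent sub keys from one seed instead of mutating global state.\nroot_key = mx.random.key(0)\nsub_keys = mx.random.split(root_key)  # num defaults to 2; first axis indexes keys\n\na = mx.random.uniform(shape=(3,), key=sub_keys[0])\nb = mx.random.normal(shape=(3,), key=sub_keys[1])\n\n# The same key always reproduces the same values.\nassert mx.array_equal(a, mx.random.uniform(shape=(3,), key=sub_keys[0]))\n"
  },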
  {
    "path": ".mlx_typings/mlx/nn/__init__.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom layers import *\nfrom utils import *\n\nfrom . import init as init\nfrom . import losses as losses\n"
  },
  {
    "path": ".mlx_typings/mlx/nn/init.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import Callable, Literal\n\nimport mlx.core as mx\n\ndef constant(value: float, dtype: mx.Dtype = ...) -> Callable[[mx.array], mx.array]:\n    r\"\"\"An initializer that returns an array filled with ``value``.\n\n    Args:\n        value (float): The value to fill the array with.\n        dtype (Dtype, optional): The data type of the array. Default:\n          ``float32``.\n\n    Returns:\n        Callable[[array], array]: An initializer that returns an array with the\n        same shape as the input, filled with ``value``.\n\n    Example:\n\n        >>> init_fn = nn.init.constant(0.5)\n        >>> init_fn(mx.zeros((2, 2)))\n        array([[0.5, 0.5],\n               [0.5, 0.5]], dtype=float32)\n    \"\"\"\n\ndef normal(\n    mean: float = ..., std: float = ..., dtype: mx.Dtype = ...\n) -> Callable[[mx.array], mx.array]:\n    r\"\"\"An initializer that returns samples from a normal distribution.\n\n    Args:\n        mean (float, optional): Mean of the normal distribution. Default:\n          ``0.0``.\n        std (float, optional): Standard deviation of the normal distribution.\n          Default: ``1.0``.\n        dtype (Dtype, optional): The data type of the array. Default:\n          ``float32``.\n\n    Returns:\n        Callable[[array], array]: An initializer that returns an array with the\n        same shape as the input, filled with samples from a normal distribution.\n\n    Example:\n\n        >>> init_fn = nn.init.normal()\n        >>> init_fn(mx.zeros((2, 2)))\n        array([[-0.982273, -0.534422],\n               [0.380709, 0.0645099]], dtype=float32)\n    \"\"\"\n\ndef uniform(\n    low: float = ..., high: float = ..., dtype: mx.Dtype = ...\n) -> Callable[[mx.array], mx.array]:\n    r\"\"\"An initializer that returns samples from a uniform distribution.\n\n    Args:\n        low (float, optional): The lower bound of the uniform distribution.\n          Default: ``0.0``.\n        high (float, optional): The upper bound of the uniform distribution.\n          Default: ``1.0``\n        dtype (Dtype, optional): The data type of the array. Default: ``float32``.\n\n    Returns:\n        Callable[[array], array]: An initializer that returns an array\n        with the same shape as the input, filled with samples from a uniform\n        distribution\n\n    Example:\n\n        >>> init_fn = nn.init.uniform(low=0, high=1)\n        >>> init_fn(mx.zeros((2, 2)))\n        array([[0.883935, 0.863726],\n               [0.617261, 0.417497]], dtype=float32)\n    \"\"\"\n\ndef identity(dtype: mx.Dtype = ...) -> Callable[[mx.array], mx.array]:\n    r\"\"\"An initializer that returns an identity matrix.\n\n    Args:\n        dtype (Dtype, optional): The data type of the array. Defaults:\n          ``float32``.\n\n    Returns:\n        Callable[[array], array]: An initializer that returns an identity\n        matrix with the same shape as the input.\n\n    Example:\n\n        >>> init_fn = nn.init.identity()\n        >>> init_fn(mx.zeros((2, 2)))\n        array([[1, 0],\n               [0, 1]], dtype=float32)\n    \"\"\"\n\ndef glorot_normal(dtype: mx.Dtype = ...) -> Callable[[mx.array, float], mx.array]:\n    r\"\"\"A Glorot normal initializer.\n\n    This initializer samples from a normal distribution with a standard\n    deviation computed from the number of input (``fan_in``) and output\n    (``fan_out``) units according to:\n\n    .. 
math::\n        \\sigma = \\gamma \\sqrt{\\frac{2.0}{\\text{fan\\_in} + \\text{fan\\_out}}}\n\n    For more details see the original reference: `Understanding the difficulty\n    of training deep feedforward neural networks\n    <https://proceedings.mlr.press/v9/glorot10a.html>`_\n\n    Args:\n        dtype (Dtype, optional): The data type of the array. Default: ``float32``.\n\n    Returns:\n        Callable[[array, float], array]: An initializer that returns an array\n        with the same shape as the input, filled with samples from the Glorot\n        normal distribution.\n\n    Example:\n\n        >>> init_fn = nn.init.glorot_normal()\n        >>> init_fn(mx.zeros((2, 2)))\n        array([[0.191107, 1.61278],\n               [-0.150594, -0.363207]], dtype=float32)\n        >>> init_fn(mx.zeros((2, 2)), gain=4.0)\n        array([[1.89613, -4.53947],\n               [4.48095, 0.995016]], dtype=float32)\n    \"\"\"\n\ndef glorot_uniform(dtype: mx.Dtype = ...) -> Callable[[mx.array, float], mx.array]:\n    r\"\"\"A Glorot uniform initializer.\n\n    This initializer samples from a uniform distribution with a range\n    computed from the number of input (``fan_in``) and output (``fan_out``)\n    units according to:\n\n    .. math::\n        \\sigma = \\gamma \\sqrt{\\frac{6.0}{\\text{fan\\_in} + \\text{fan\\_out}}}\n\n    For more details see the original reference: `Understanding the difficulty\n    of training deep feedforward neural networks\n    <https://proceedings.mlr.press/v9/glorot10a.html>`_\n\n    Args:\n        dtype (Dtype, optional): The data type of the array. Default: ``float32``.\n\n    Returns:\n        Callable[[array, float], array]: An initializer that returns an array\n        with the same shape as the input, filled with samples from the Glorot\n        uniform distribution.\n\n    Example:\n\n        >>> init_fn = nn.init.glorot_uniform()\n        >>> init_fn(mx.zeros((2, 2)))\n        array([[0.223404, -0.890597],\n               [-0.379159, -0.776856]], dtype=float32)\n        >>> init_fn(mx.zeros((2, 2)), gain=4.0)\n        array([[-1.90041, 3.02264],\n               [-0.912766, 4.12451]], dtype=float32)\n    \"\"\"\n\ndef he_normal(\n    dtype: mx.Dtype = ...,\n) -> Callable[[mx.array, Literal[\"fan_in\", \"fan_out\"], float], mx.array]:\n    r\"\"\"Build a He normal initializer.\n\n    This initializer samples from a normal distribution with a standard\n    deviation computed from the number of input (``fan_in``) or output\n    (``fan_out``) units according to:\n\n    .. math::\n        \\sigma = \\gamma \\frac{1}{\\sqrt{\\text{fan}}}\n\n    where :math:`\\text{fan}` is either the number of input units when the\n    ``mode`` is ``\"fan_in\"`` or output units when the ``mode`` is\n    ``\"fan_out\"``.\n\n    For more details see the original reference: `Delving Deep into Rectifiers:\n    Surpassing Human-Level Performance on ImageNet Classification\n    <https://arxiv.org/abs/1502.01852>`_\n\n    Args:\n        dtype (Dtype, optional): The data type of the array. 
Defaults to mx.float32.\n\n    Returns:\n        Callable[[array, str, float], array]: An initializer that returns an\n        array with the same shape as the input, filled with samples from the He\n        normal distribution.\n\n    Example:\n\n        >>> init_fn = nn.init.he_normal()\n        >>> init_fn(mx.zeros((2, 2)))  # uses fan_in\n        array([[-1.25211, 0.458835],\n               [-0.177208, -0.0137595]], dtype=float32)\n        >>> init_fn(mx.zeros((2, 2)), mode=\"fan_out\", gain=5)\n        array([[5.6967, 4.02765],\n               [-4.15268, -2.75787]], dtype=float32)\n    \"\"\"\n\ndef he_uniform(\n    dtype: mx.Dtype = ...,\n) -> Callable[[mx.array, Literal[\"fan_in\", \"fan_out\"], float], mx.array]:\n    r\"\"\"A He uniform (Kaiming uniform) initializer.\n\n    This initializer samples from a uniform distribution with a range\n    computed from the number of input (``fan_in``) or output (``fan_out``)\n    units according to:\n\n    .. math::\n\n        \\sigma = \\gamma \\sqrt{\\frac{3.0}{\\text{fan}}}\n\n    where :math:`\\text{fan}` is either the number of input units when the\n    ``mode`` is ``\"fan_in\"`` or output units when the ``mode`` is\n    ``\"fan_out\"``.\n\n    For more details see the original reference: `Delving Deep into Rectifiers:\n    Surpassing Human-Level Performance on ImageNet Classification\n    <https://arxiv.org/abs/1502.01852>`_\n\n\n    Args:\n        dtype (Dtype, optional): The data type of the array. Default: ``float32``.\n\n    Returns:\n        Callable[[array, str, float], array]: An initializer that returns an\n        array with the same shape as the input, filled with samples from  the\n        He uniform distribution.\n\n    Example:\n\n        >>> init_fn = nn.init.he_uniform()\n        >>> init_fn(mx.zeros((2, 2)))  # uses fan_in\n        array([[0.0300242, -0.0184009],\n               [0.793615, 0.666329]], dtype=float32)\n        >>> init_fn(mx.zeros((2, 2)), mode=\"fan_out\", gain=5)\n        array([[-1.64331, -2.16506],\n               [1.08619, 5.79854]], dtype=float32)\n    \"\"\"\n\ndef sparse(\n    sparsity: float, mean: float = ..., std: float = ..., dtype: mx.Dtype = ...\n) -> Callable[[mx.array], mx.array]:\n    r\"\"\"An initializer that returns a sparse matrix.\n\n    Args:\n        sparsity (float): The fraction of elements in each column to be set to\n        zero.\n        mean (float, optional): Mean of the normal distribution. Default:\n          ``0.0``.\n        std (float, optional): Standard deviation of the normal distribution.\n          Default: ``1.0``.\n        dtype (Dtype, optional): The data type of the array. Default:\n          ``float32``.\n\n    Returns:\n        Callable[[array], array]: An initializer that returns an array with the\n        same shape as the input, filled with samples from a normal distribution.\n\n    Example:\n\n        >>> init_fn = nn.init.sparse(sparsity=0.5)\n        >>> init_fn(mx.zeros((2, 2)))\n        array([[-1.91187, -0.117483],\n       [0, 0]], dtype=float32)\n    \"\"\"\n\ndef orthogonal(\n    gain: float = ..., dtype: mx.Dtype = ...\n) -> Callable[[mx.array], mx.array]:\n    r\"\"\"An initializer that returns an orthogonal matrix.\n\n    Args:\n        gain (float, optional): Scaling factor for the orthogonal matrix.\n            Default: ``1.0``.\n        dtype (Dtype, optional): Data type of the array. 
Default: ``float32``.\n\n    Returns:\n        Callable[[array], array]: An initializer that returns\n        an orthogonal matrix with the same shape as the input.\n    \"\"\"\n"
  },
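  {
    "path": "examples/nn_init_sketch.py",
    "content": "\"\"\"Hypothetical usage sketch, not part of MLX or these stubs: applying the\n``mlx.nn.init`` initializers documented above to a layer's weight. The file\npath is an illustrative assumption only.\"\"\"\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\n# An initializer maps an input array to a fresh array of the same shape, so\n# re-initializing a parameter is a plain call-and-assign.\ninit_fn = nn.init.he_normal()\nlayer = nn.Linear(16, 32)\nlayer.weight = init_fn(layer.weight)  # mode defaults to \"fan_in\"\nmx.eval(layer.parameters())\n"
  },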
  {
    "path": ".mlx_typings/mlx/nn/layers/__init__.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom activations import *\nfrom base import *\nfrom containers import *\nfrom convolution import *\nfrom convolution_transpose import *\nfrom distributed import *\nfrom dropout import *\nfrom embedding import *\nfrom linear import *\nfrom normalization import *\nfrom pooling import *\nfrom positional_encoding import *\nfrom quantized import *\nfrom recurrent import *\nfrom transformer import *\nfrom upsample import *\n"
  },
  {
    "path": ".mlx_typings/mlx/nn/layers/activations.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom functools import partial\nfrom typing import Any\n\nimport mlx.core as mx\nfrom base import Module\n\n@partial(mx.compile, shapeless=True)\ndef sigmoid(x: mx.array) -> mx.array:\n    r\"\"\"Applies the sigmoid function.\n\n    .. math::\n        \\text{Sigmoid}(x) = \\sigma(x) = \\frac{1}{1 + \\exp(-x)}\n    \"\"\"\n\n@partial(mx.compile, shapeless=True)\ndef relu(x: mx.array) -> mx.array:\n    r\"\"\"Applies the Rectified Linear Unit.\n\n    Simply ``mx.maximum(x, 0)``.\n    \"\"\"\n\n@partial(mx.compile, shapeless=True)\ndef relu2(x: mx.array) -> mx.array:\n    r\"\"\"Applies the ReLU² activation function.\n\n    Applies :math:`\\max(0, x)^2` element wise.\n    \"\"\"\n\n@partial(mx.compile, shapeless=True)\ndef relu6(x: mx.array) -> mx.array:\n    r\"\"\"Applies the Rectified Linear Unit 6.\n\n    Applies :math:`\\min(\\max(x, 0), 6)` element wise.\n    \"\"\"\n\n@partial(mx.compile, shapeless=True)\ndef leaky_relu(x: mx.array, negative_slope=...) -> mx.array:\n    r\"\"\"Applies the Leaky Rectified Linear Unit.\n\n    Simply ``mx.maximum(negative_slope * x, x)``.\n    \"\"\"\n\n@partial(mx.compile, shapeless=True)\ndef log_softmax(x: mx.array, axis=...):\n    r\"\"\"Applies the Log Softmax function.\n\n    Applies :math:`x + \\log \\sum_i e^{x_i}` element wise.\n    \"\"\"\n\n@partial(mx.compile, shapeless=True)\ndef elu(x: mx.array, alpha=...) -> mx.array:\n    r\"\"\"Applies the Exponential Linear Unit.\n\n    Simply ``mx.where(x > 0, x, alpha * (mx.exp(x) - 1))``.\n    \"\"\"\n\n@partial(mx.compile, shapeless=True)\ndef softmax(x: mx.array, axis=...) -> mx.array:\n    r\"\"\"Applies the Softmax function.\n\n    Applies :math:`\\frac{e^{x_i}}{\\sum_j e^{x_j}}` element wise.\n    \"\"\"\n\n@partial(mx.compile, shapeless=True)\ndef softplus(x: mx.array) -> mx.array:\n    r\"\"\"Applies the Softplus function.\n\n    Applies :math:`\\log(1 + \\exp(x))` element wise.\n    \"\"\"\n\n@partial(mx.compile, shapeless=True)\ndef softsign(x: mx.array) -> mx.array:\n    r\"\"\"Applies the Softsign function.\n\n    Applies :math:`\\frac{x}{1 + |x|}` element wise.\n    \"\"\"\n\n@partial(mx.compile, shapeless=True)\ndef softshrink(x: mx.array, lambd: float = ...) -> mx.array:\n    r\"\"\"Applies the Softshrink activation function.\n\n    .. math::\n        \\text{softshrink}(x) = \\begin{cases}\n        x - \\lambda & \\text{if } x > \\lambda \\\\\n        x + \\lambda & \\text{if } x < -\\lambda \\\\\n        0 & \\text{otherwise}\n        \\end{cases}\n    \"\"\"\n\n@partial(mx.compile, shapeless=True)\ndef celu(x: mx.array, alpha=...) -> mx.array:\n    r\"\"\"Applies the Continuously Differentiable Exponential Linear Unit.\n\n    Applies :math:`\\max(0, x) + \\min(0, \\alpha * (\\exp(x / \\alpha) - 1))`\n    element wise.\n    \"\"\"\n\n@partial(mx.compile, shapeless=True)\ndef silu(x: mx.array) -> mx.array:\n    r\"\"\"Applies the Sigmoid Linear Unit. Also known as Swish.\n\n    Applies :math:`x \\sigma(x)` element wise, where :math:`\\sigma(\\cdot)` is\n    the logistic sigmoid.\n    \"\"\"\n\n@partial(mx.compile, shapeless=True)\ndef log_sigmoid(x: mx.array) -> mx.array:\n    r\"\"\"Applies the Log Sigmoid function.\n\n    Applies :math:`\\log(\\sigma(x)) = -\\log(1 + e^{-x})` element wise.\n    \"\"\"\n\n@partial(mx.compile, shapeless=True)\ndef gelu(x: mx.array) -> mx.array:\n    r\"\"\"Applies the Gaussian Error Linear Units function.\n\n    .. 
math::\n        \\textrm{GELU}(x) = x * \\Phi(x)\n\n    where :math:`\\Phi(x)` is the Gaussian CDF.\n\n    See also :func:`gelu_approx` and :func:`gelu_fast_approx` for faster\n    approximations.\n    \"\"\"\n\n@partial(mx.compile, shapeless=True)\ndef gelu_approx(x: mx.array) -> mx.array:\n    r\"\"\"An approximation to Gaussian Error Linear Unit.\n\n    See :func:`gelu` for the exact computation.\n\n    This function approximates ``gelu`` with a maximum absolute error :math:`<\n    0.0005` in the range :math:`[-6, 6]` using the following\n\n    .. math::\n\n        x = 0.5 * x * \\left(1 + \\text{Tanh}\\left((\\sqrt{2 / \\pi} * \\left(x + 0.044715 * x^3\\right)\\right)\\right)\n\n    \"\"\"\n\n@partial(mx.compile, shapeless=True)\ndef gelu_fast_approx(x: mx.array) -> mx.array:\n    r\"\"\"A fast approximation to Gaussian Error Linear Unit.\n\n    See :func:`gelu` for the exact computation.\n\n    This function approximates ``gelu`` with a maximum absolute error :math:`<\n    0.015` in the range :math:`[-6, 6]` using the following\n\n    .. math::\n\n        x = x \\sigma\\left(1.702 x\\right)\n\n    where :math:`\\sigma(\\cdot)` is the logistic sigmoid.\n\n    References:\n    - https://github.com/hendrycks/GELUs\n    - https://arxiv.org/abs/1606.08415\n    \"\"\"\n\ndef glu(x: mx.array, axis: int = ...) -> mx.array:\n    r\"\"\"Applies the gated linear unit function.\n\n    This function splits the ``axis`` dimension of the input into two halves\n    (:math:`a` and :math:`b`) and applies :math:`a * \\sigma(b)`.\n\n    .. math::\n        \\textrm{GLU}(x) = a * \\sigma(b)\n\n    Args:\n        axis (int): The dimension to split along. Default: ``-1``\n    \"\"\"\n\n@partial(mx.compile, shapeless=True)\ndef step(x: mx.array, threshold: float = ...) -> mx.array:\n    r\"\"\"Applies the Step Activation Function.\n\n    This function implements a binary step activation, where the output is set\n    to 1 if the input is greater than a specified threshold, and 0 otherwise.\n\n    .. math::\n        \\text{step}(x) = \\begin{cases}\n        0 & \\text{if } x < \\text{threshold} \\\\\n        1 & \\text{if } x \\geq \\text{threshold}\n        \\end{cases}\n\n    Args:\n        threshold: The value to threshold at.\n    \"\"\"\n\n@partial(mx.compile, shapeless=True)\ndef selu(x: mx.array) -> mx.array:\n    r\"\"\"Applies the Scaled Exponential Linear Unit.\n\n    .. math::\n        \\text{selu}(x) = \\begin{cases}\n        \\lambda x & \\text{if } x > 0 \\\\\n        \\lambda \\alpha (\\exp(x) - 1) & \\text{if } x \\leq 0\n        \\end{cases}\n\n    where :math:`\\lambda = 1.0507` and :math:`\\alpha = 1.67326`.\n\n    See also :func:`elu`.\n    \"\"\"\n\n@partial(mx.compile, shapeless=True)\ndef prelu(x: mx.array, alpha: mx.array) -> mx.array:\n    r\"\"\"Applies the element-wise parametric ReLU.\n\n    .. math::\n        \\text{PReLU}(x) = \\max(0,x) + a * \\min(0,x)\n\n    where :math:`a` is an array.\n    \"\"\"\n\n@partial(mx.compile, shapeless=True)\ndef mish(x: mx.array) -> mx.array:\n    r\"\"\"Applies the Mish function, element-wise.\n\n    Mish: A Self Regularized Non-Monotonic Neural Activation Function.\n\n    Reference: https://arxiv.org/abs/1908.08681\n\n    .. math::\n        \\text{Mish}(x) = x * \\text{Tanh}(\\text{Softplus}(x))\n\n    \"\"\"\n\n@partial(mx.compile, shapeless=True)\ndef hardswish(x: mx.array) -> mx.array:\n    r\"\"\"Applies the hardswish function, element-wise.\n\n    .. 
math::\n        \\text{Hardswish}(x) = x * \\min(\\max(x + 3, 0), 6) / 6\n    \"\"\"\n\n@partial(mx.compile, shapeless=True)\ndef hard_tanh(x: mx.array, min_val=..., max_val=...) -> mx.array:\n    r\"\"\"Applies the HardTanh function.\n\n    Applies :math:`\\max(\\min(x, \\text{max\\_val}), \\text{min\\_val})` element-wise.\n    \"\"\"\n\n@partial(mx.compile, shapeless=True)\ndef hard_shrink(x: mx.array, lambd=...) -> mx.array:\n    r\"\"\"Applies the HardShrink activation function.\n\n    .. math::\n        \\text{hardshrink}(x) = \\begin{cases}\n        x & \\text{if } x > \\lambda \\\\\n        x & \\text{if } x < -\\lambda \\\\\n        0 & \\text{otherwise}\n        \\end{cases}\n    \"\"\"\n\n@partial(mx.compile, shapeless=True)\ndef softmin(x: mx.array, axis=...) -> mx.array:\n    r\"\"\"Applies the Softmin function.\n\n    Applies :math:`\\frac{e^{-x_i}}{\\sum_j e^{-x_j}}` element-wise.\n    \"\"\"\n\ndef tanh(x: mx.array) -> mx.array:\n    \"\"\"Applies the hyperbolic tangent function.\n\n    Simply ``mx.tanh(x)``.\n    \"\"\"\n\nclass GLU(Module):\n    r\"\"\"Applies the gated linear unit function.\n\n    This function splits the ``axis`` dimension of the input into two halves\n    (:math:`a` and :math:`b`) and applies :math:`a * \\sigma(b)`.\n\n    .. math::\n        \\textrm{GLU}(x) = a * \\sigma(b)\n\n    Args:\n        axis (int): The dimension to split along. Default: ``-1``\n    \"\"\"\n    def __init__(self, axis: int = ...) -> None: ...\n    def __call__(self, x) -> Any: ...\n\n@_make_activation_module(sigmoid)\nclass Sigmoid(Module):\n    r\"\"\"Applies the sigmoid function, element-wise.\n\n    .. math::\n        \\text{Sigmoid}(x) = \\sigma(x) = \\frac{1}{1 + \\exp(-x)}\n    \"\"\"\n\n@_make_activation_module(mish)\nclass Mish(Module):\n    r\"\"\"Applies the Mish function, element-wise.\n\n    Reference: https://arxiv.org/abs/1908.08681\n\n    .. math::\n        \\text{Mish}(x) = x * \\text{Tanh}(\\text{Softplus}(x))\n\n    \"\"\"\n\n@_make_activation_module(relu)\nclass ReLU(Module):\n    r\"\"\"Applies the Rectified Linear Unit.\n        Simply ``mx.maximum(x, 0)``.\n\n    See :func:`relu` for the functional equivalent.\n    \"\"\"\n\n@_make_activation_module(relu2)\nclass ReLU2(Module):\n    r\"\"\"Applies the ReLU² activation function.\n\n    See :func:`relu2` for the functional equivalent.\n    \"\"\"\n\n@_make_activation_module(relu6)\nclass ReLU6(Module):\n    r\"\"\"Applies the Rectified Linear Unit 6.\n\n    See :func:`relu6` for the functional equivalent.\n    \"\"\"\n\nclass LeakyReLU(Module):\n    r\"\"\"Applies the Leaky Rectified Linear Unit.\n\n    Simply ``mx.maximum(negative_slope * x, x)``.\n\n    Args:\n        negative_slope: Controls the angle of the negative slope. Default: ``1e-2``\n    \"\"\"\n    def __init__(self, negative_slope=...) -> None: ...\n    def __call__(self, x): ...\n\nclass ELU(Module):\n    r\"\"\"Applies the Exponential Linear Unit.\n        Simply ``mx.where(x > 0, x, alpha * (mx.exp(x) - 1))``.\n\n    See :func:`elu` for the functional equivalent.\n\n    Args:\n        alpha: the :math:`\\alpha` value for the ELU formulation. Default: ``1.0``\n    \"\"\"\n    def __init__(self, alpha=...) 
-> None: ...\n    def __call__(self, x): ...\n\n@_make_activation_module(softmax)\nclass Softmax(Module):\n    r\"\"\"Applies the Softmax function.\n\n    See :func:`softmax` for the functional equivalent.\n    \"\"\"\n\n@_make_activation_module(softplus)\nclass Softplus(Module):\n    r\"\"\"Applies the Softplus function.\n\n    See :func:`softplus` for the functional equivalent.\n    \"\"\"\n\n@_make_activation_module(softsign)\nclass Softsign(Module):\n    r\"\"\"Applies the Softsign function.\n\n    See :func:`softsign` for the functional equivalent.\n    \"\"\"\n\nclass Softshrink(Module):\n    r\"\"\"Applies the Softshrink function.\n\n    See :func:`softshrink` for the functional equivalent.\n\n    Args:\n        lambd: the :math:`\\lambda` value for Softshrink. Default: ``0.5``\n    \"\"\"\n    def __init__(self, lambd=...) -> None: ...\n    def __call__(self, x): ...\n\nclass CELU(Module):\n    r\"\"\"Applies the Continuously Differentiable Exponential Linear Unit.\n        Applies :math:`\\max(0, x) + \\min(0, \\alpha * (\\exp(x / \\alpha) - 1))`\n        element wise.\n\n    See :func:`celu` for the functional equivalent.\n\n    Args:\n        alpha: the :math:`\\alpha` value for the CELU formulation. Default: ``1.0``\n    \"\"\"\n    def __init__(self, alpha=...) -> None: ...\n    def __call__(self, x): ...\n\n@_make_activation_module(silu)\nclass SiLU(Module):\n    r\"\"\"Applies the Sigmoid Linear Unit. Also known as Swish.\n\n    See :func:`silu` for the functional equivalent.\n    \"\"\"\n\n@_make_activation_module(log_softmax)\nclass LogSoftmax(Module):\n    r\"\"\"Applies the Log Softmax function.\n\n    See :func:`log_softmax` for the functional equivalent.\n    \"\"\"\n\n@_make_activation_module(log_sigmoid)\nclass LogSigmoid(Module):\n    r\"\"\"Applies the Log Sigmoid function.\n\n    See :func:`log_sigmoid` for the functional equivalent.\n    \"\"\"\n\nclass PReLU(Module):\n    r\"\"\"Applies the element-wise parametric ReLU.\n        Applies :math:`\\max(0, x) + a * \\min(0, x)` element wise, where :math:`a`\n        is an array.\n\n    See :func:`prelu` for the functional equivalent.\n\n    Args:\n        num_parameters: number of :math:`a` to learn. Default: ``1``\n        init: the initial value of :math:`a`. Default: ``0.25``\n    \"\"\"\n    def __init__(self, num_parameters=..., init=...) -> None: ...\n    def __call__(self, x: mx.array): ...\n\nclass GELU(Module):\n    r\"\"\"Applies the Gaussian Error Linear Units.\n\n    .. math::\n        \\textrm{GELU}(x) = x * \\Phi(x)\n\n    where :math:`\\Phi(x)` is the Gaussian CDF.\n\n    However, if ``approx`` is set to 'precise' or 'fast' it applies\n\n    .. math::\n        \\textrm{GELUApprox}(x) &= 0.5 * x * \\left(1 + \\text{Tanh}\\left((\\sqrt{2 / \\pi} * \\left(x + 0.044715 * x^3\\right)\\right)\\right) \\\\\n        \\textrm{GELUFast}(x) &= x * \\sigma\\left(1.702 * x\\right)\n\n    respectively.\n\n    .. note::\n       For compatibility with the PyTorch API, 'tanh' can be used as an alias\n       for 'precise'.\n\n    See :func:`gelu`, :func:`gelu_approx` and :func:`gelu_fast_approx` for the\n    functional equivalents and information regarding error bounds.\n\n\n    Args:\n        approx ('none' | 'precise' | 'fast'): Which approximation to gelu to use if any.\n    \"\"\"\n    def __init__(self, approx=...) 
-> None: ...\n    def __call__(self, x): ...\n\n@_make_activation_module(tanh)\nclass Tanh(Module):\n    r\"\"\"Applies the hyperbolic tangent function.\n\n    See :func:`tanh` for the functional equivalent.\n    \"\"\"\n\n@_make_activation_module(hardswish)\nclass Hardswish(Module):\n    r\"\"\"Applies the hardswish function, element-wise.\n\n    See :func:`hardswish` for the functional equivalent.\n    \"\"\"\n\nclass Step(Module):\n    r\"\"\"Applies the Step Activation Function.\n\n    This function implements a binary step activation, where the output is set\n    to 1 if the input is greater than a specified threshold, and 0 otherwise.\n\n    .. math::\n        \\text{step}(x) = \\begin{cases}\n        0 & \\text{if } x < \\text{threshold} \\\\\n        1 & \\text{if } x \\geq \\text{threshold}\n        \\end{cases}\n\n    Args:\n        threshold: The value to threshold at.\n    \"\"\"\n    def __init__(self, threshold: float = ...) -> None: ...\n    def __call__(self, x: mx.array): ...\n\n@_make_activation_module(selu)\nclass SELU(Module):\n    r\"\"\"Applies the Scaled Exponential Linear Unit.\n\n    See :func:`selu` for the functional equivalent.\n    \"\"\"\n\n@_make_activation_module(hard_tanh)\nclass HardTanh(Module):\n    r\"\"\"Applies the HardTanh function.\n\n    See :func:`hard_tanh` for the functional equivalent.\n    \"\"\"\n\n@_make_activation_module(hard_shrink)\nclass HardShrink(Module):\n    r\"\"\"Applies the HardShrink function.\n\n    See :func:`hard_shrink` for the functional equivalent.\n\n    Args:\n        lambd: the :math:`\\lambda` value for Hardshrink. Default: ``0.5``\n    \"\"\"\n\n@_make_activation_module(softmin)\nclass Softmin(Module):\n    r\"\"\"Applies the Softmin function.\n\n    See :func:`softmin` for the functional equivalent.\n    \"\"\"\n"
  },
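  {
    "path": "examples/activations_sketch.py",
    "content": "\"\"\"Hypothetical usage sketch, not part of MLX or these stubs: the functional\nactivations documented above versus their Module wrappers. The file path is an\nillustrative assumption only.\"\"\"\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nx = mx.array([-2.0, 0.0, 2.0])\n\n# Functional form and Module form compute the same thing; the Module form is\n# convenient inside Sequential or as a named submodule.\nprint(nn.relu(x))\nprint(nn.ReLU()(x))\n\n# GELU exposes its approximations through the `approx` constructor argument.\nprint(nn.GELU(approx='fast')(x))\n"
  },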
  {
    "path": ".mlx_typings/mlx/nn/layers/base.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import Any, Callable, List, Optional, Tuple, Union\n\nimport mlx.core as mx\n\nclass Module(dict):\n    \"\"\"Base class for building neural networks with MLX.\n\n    All the layers provided in :mod:`layers` subclass this class and\n    your models should do the same.\n\n    A ``Module`` can contain other ``Module`` instances or :class:`mlx.core.array`\n    instances in arbitrary nesting of python lists or dicts. The ``Module``\n    then allows recursively extracting all the :class:`mlx.core.array` instances\n    using :meth:`Module.parameters`.\n\n    In addition, the ``Module`` has the concept of trainable and non trainable\n    parameters (called \"frozen\"). When using :func:`value_and_grad`\n    the gradients are returned only with respect to the trainable parameters.\n    All arrays in a module are trainable unless they are added in the \"frozen\"\n    set by calling :meth:`freeze`.\n\n    .. code-block:: python\n\n        import mlx.core as mx\n        import mlx.nn as nn\n\n        class MyMLP(nn.Module):\n            def __init__(self, in_dims: int, out_dims: int, hidden_dims: int = 16):\n                super().__init__()\n\n                self.in_proj = nn.Linear(in_dims, hidden_dims)\n                self.out_proj = nn.Linear(hidden_dims, out_dims)\n\n            def __call__(self, x):\n                x = self.in_proj(x)\n                x = mx.maximum(x, 0)\n                return self.out_proj(x)\n\n        model = MyMLP(2, 1)\n\n        # All the model parameters are created but since MLX is lazy by\n        # default, they are not evaluated yet. Calling `mx.eval` actually\n        # allocates memory and initializes the parameters.\n        mx.eval(model.parameters())\n\n        # Setting a parameter to a new value is as simply as accessing that\n        # parameter and assigning a new array to it.\n        model.in_proj.weight = model.in_proj.weight * 2\n        mx.eval(model.parameters())\n    \"\"\"\n\n    __call__: Callable\n    def __init__(self) -> None:\n        \"\"\"Should be called by the subclasses of ``Module``.\"\"\"\n\n    @property\n    def training(self):  # -> bool:\n        \"\"\"Boolean indicating if the model is in training mode.\"\"\"\n\n    @property\n    def state(self):  # -> Self:\n        \"\"\"The module's state dictionary\n\n        The module's state dictionary contains any attribute set on the\n        module including parameters in :meth:`Module.parameters`\n\n        Unlike :meth:`Module.parameters`, the :attr:`Module.state` property is\n        a reference to the module's state. 
Updates to it will be reflected in\n        the original module.\n        \"\"\"\n\n    def __repr__(self):  # -> str:\n        ...\n    def __getattr__(self, key: str):  # -> None:\n        ...\n    def __setattr__(self, key: str, val: Any):  # -> None:\n        ...\n    def __delattr__(self, name):  # -> None:\n        ...\n    def load_weights(\n        self,\n        file_or_weights: Union[str, List[Tuple[str, mx.array]]],\n        strict: bool = ...,\n    ) -> Module:\n        \"\"\"\n        Update the model's weights from a ``.npz``, a ``.safetensors`` file, or a list.\n\n        Args:\n            file_or_weights (str or list(tuple(str, mx.array))): The path to\n                the weights ``.npz`` file (``.npz`` or ``.safetensors``) or a list\n                of pairs of parameter names and arrays.\n            strict (bool, optional): If ``True`` then checks that the provided\n              weights exactly match the parameters of the model. Otherwise,\n              only the weights actually contained in the model are loaded and\n              shapes are not checked. Default: ``True``.\n\n        Returns:\n            The module instance after updating the weights.\n\n        Example:\n\n            .. code-block:: python\n\n                import mlx.core as mx\n                import mlx.nn as nn\n                model = nn.Linear(10, 10)\n\n                # Load from file\n                model.load_weights(\"weights.npz\")\n\n                # Load from .safetensors file\n                model.load_weights(\"weights.safetensors\")\n\n                # Load from list\n                weights = [\n                    (\"weight\", mx.random.uniform(shape=(10, 10))),\n                    (\"bias\",  mx.zeros((10,))),\n                ]\n                model.load_weights(weights)\n\n                # Missing weight\n                weights = [\n                    (\"weight\", mx.random.uniform(shape=(10, 10))),\n                ]\n\n                # Raises a ValueError exception\n                model.load_weights(weights)\n\n                # Ok, only updates the weight but not the bias\n                model.load_weights(weights, strict=False)\n        \"\"\"\n\n    def save_weights(self, file: str):  # -> None:\n        \"\"\"\n        Save the model's weights to a file. 
The saving method is determined by the file extension:\n        - ``.npz`` will use :func:`mx.savez`\n        - ``.safetensors`` will use :func:`mx.save_safetensors`\n        \"\"\"\n\n    @staticmethod\n    def is_module(value):  # -> bool:\n        ...\n    @staticmethod\n    def valid_child_filter(module, key, value):  # -> bool:\n        ...\n    @staticmethod\n    def valid_parameter_filter(module, key, value):  # -> bool:\n        ...\n    @staticmethod\n    def trainable_parameter_filter(module, key, value):  # -> bool:\n        ...\n    def filter_and_map(\n        self,\n        filter_fn: Callable[[Module, str, Any], bool],\n        map_fn: Optional[Callable] = ...,\n        is_leaf_fn: Optional[Callable[[Module, str, Any], bool]] = ...,\n    ):  # -> dict[Any, Any | dict[Any, Any | dict[Any, Any] | list[Any]] | dict[Any, Any] | list[Any]]:\n        \"\"\"Recursively filter the contents of the module using ``filter_fn``,\n        namely only select keys and values where ``filter_fn`` returns true.\n\n        This is used to implement :meth:`parameters` and :meth:`trainable_parameters`\n        but it can also be used to extract any subset of the module's parameters.\n\n        Args:\n            filter_fn (Callable): Given a value, the key in which it is found\n                and the containing module, decide whether to keep the value or\n                drop it.\n            map_fn (Callable, optional): Optionally transform the value before\n                returning it.\n            is_leaf_fn (Callable, optional): Given a value, the key in which it\n                is found and the containing module decide if it is a leaf.\n\n        Returns:\n            A dictionary containing the contents of the module recursively filtered\n        \"\"\"\n\n    def parameters(\n        self,\n    ) -> mx.MX_ARRAY_TREE:\n        \"\"\"Recursively return all the :class:`mlx.core.array` members of this Module\n        as a dict of dicts and lists.\"\"\"\n\n    def trainable_parameters(\n        self,\n    ) -> mx.MX_ARRAY_TREE:  # -> dict[Any, Any | dict[Any, Any | dict[Any, Any] | list[Any]] | dict[Any, Any] | list[Any]]:\n        \"\"\"Recursively return all the non frozen :class:`mlx.core.array` members of\n        this Module as a dict of dicts and lists.\"\"\"\n\n    def children(\n        self,\n    ) -> mx.MX_ARRAY_TREE:  # -> dict[Any, Any | dict[Any, Any | dict[Any, Any] | list[Any]] | dict[Any, Any] | list[Any]]:\n        \"\"\"Return the direct descendants of this Module instance.\"\"\"\n\n    def leaf_modules(\n        self,\n    ) -> mx.MX_ARRAY_TREE:  # -> dict[Any, Any | dict[Any, Any | dict[Any, Any] | list[Any]] | dict[Any, Any] | list[Any]]:\n        \"\"\"Return the submodules that do not contain other modules.\"\"\"\n\n    def update(self, parameters: dict[str, Any], strict: bool = ...) -> Module:\n        \"\"\"Replace the parameters of this Module with the provided ones in the\n        dict of dicts and lists.\n\n        Commonly used by the optimizer to change the model to the updated\n        (optimized) parameters. Also used by the :meth:`value_and_grad` to set the\n        tracers in the model in order to compute gradients.\n\n        The passed in parameters dictionary need not be a full dictionary\n        similar to :meth:`parameters`. 
Only the provided locations will be\n        updated.\n\n        Args:\n            parameters (dict): A complete or partial dictionary of the modules\n                parameters.\n            strict (bool): If ``True`` checks that ``parameters`` is a\n                subset of the module's parameters. Default: ``True``.\n        Returns:\n            The module instance after updating the parameters.\n        \"\"\"\n\n    def apply(\n        self,\n        map_fn: Callable[[mx.array], mx.array],\n        filter_fn: Optional[Callable[[Module, str, Any], bool]] = ...,\n    ) -> Module:\n        \"\"\"Map all the parameters using the provided ``map_fn`` and immediately\n        update the module with the mapped parameters.\n\n        For instance running ``model.apply(lambda x: x.astype(mx.float16))``\n        casts all parameters to 16 bit floats.\n\n        Args:\n            map_fn (Callable): Maps an array to another array\n            filter_fn (Callable, optional): Filter to select which arrays to\n                map (default: :meth:`Module.valid_parameter_filter`).\n\n        Returns:\n            The module instance after updating the parameters.\n        \"\"\"\n\n    def update_modules(self, modules: dict, strict: bool = ...) -> Module:\n        \"\"\"Replace the child modules of this :class:`Module` instance with the\n        provided ones in the dict of dicts and lists.\n\n        It is the equivalent of :meth:`Module.update` but for modules instead\n        of parameters and allows us to flexibly edit complex architectures by\n        programmatically swapping layers.\n\n        The passed in parameters dictionary need not be a full dictionary\n        similar to :meth:`modules`. Only the provided locations will be\n        updated.\n\n        Args:\n            modules (dict): A complete or partial dictionary of the module's\n                submodules.\n            strict (bool): If ``True`` checks that ``modules`` is a\n                subset of the child modules of this instance. Default: ``True``.\n        Returns:\n            The module instance after updating the submodules.\n        \"\"\"\n\n    def apply_to_modules(self, apply_fn: Callable[[str, Module], Any]) -> Module:\n        \"\"\"Apply a function to all the modules in this instance (including this\n        instance).\n\n        Args:\n            apply_fn (Callable): The function to apply to the modules.\n\n        Returns:\n            The module instance after updating submodules.\n        \"\"\"\n\n    def modules(self):  # -> list[Any]:\n        \"\"\"Return a list with all the modules in this instance.\n\n        Returns:\n            A list of :class:`Module` instances.\n        \"\"\"\n\n    def named_modules(self):  # -> list[Any]:\n        \"\"\"Return a list with all the modules in this instance and their name\n        with dot notation.\n\n        Returns:\n            A list of tuples (str, :class:`Module`).\n        \"\"\"\n\n    def freeze(\n        self,\n        *,\n        recurse: bool = ...,\n        keys: Optional[Union[str, List[str]]] = ...,\n        strict: bool = ...,\n    ) -> Module:\n        \"\"\"Freeze the Module's parameters or some of them. Freezing a parameter means not\n        computing gradients for it.\n\n        This function is idempotent i.e. freezing a frozen model is a no-op.\n\n        Example:\n            For instance to only train the attention parameters from a Transformer:\n\n            .. 
code-block:: python\n\n                model = nn.Transformer()\n                model.freeze()\n                model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith(\"attention\") else None)\n\n        Args:\n            recurse (bool, optional): If True then freeze the parameters of the\n                submodules as well. Default: ``True``.\n            keys (str or list[str], optional): If provided then only these\n                parameters will be frozen otherwise all the parameters of a\n                module. For instance freeze all biases by calling\n                ``module.freeze(keys=\"bias\")``.\n            strict (bool, optional): If set to ``True`` validate that the passed keys exist.\n                Default: ``False``.\n\n        Returns:\n            The module instance after freezing the parameters.\n        \"\"\"\n\n    def unfreeze(\n        self,\n        *,\n        recurse: bool = ...,\n        keys: Optional[Union[str, List[str]]] = ...,\n        strict: bool = ...,\n    ) -> Module:\n        \"\"\"Unfreeze the Module's parameters or some of them.\n\n        This function is idempotent ie unfreezing a model that is not frozen is\n        a noop.\n\n        Example:\n\n            For instance to only train the biases of a Transformer one can do:\n\n            .. code-block:: python\n\n                model = nn.Transformer()\n                model.freeze()\n                model.unfreeze(keys=\"bias\")\n\n        Args:\n            recurse (bool, optional): If True then unfreeze the parameters of the\n                submodules as well. Default: ``True``.\n            keys (str or list[str], optional): If provided then only these\n                parameters will be unfrozen otherwise all the parameters of a\n                module. For instance unfreeze all biases by calling\n                ``module.unfreeze(keys=\"bias\")``.\n            strict (bool, optional): If set to ``True`` validate that the passed keys exist.\n                Default: ``False``.\n\n        Returns:\n            The module instance after unfreezing the parameters.\n        \"\"\"\n\n    def train(self, mode: bool = ...) -> Module:\n        \"\"\"Set the model in or out of training mode.\n\n        Training mode only applies to certain layers. For example\n        :obj:`Dropout` applies a random mask in training mode, but is the\n        identity in evaluation mode.\n\n        Args:\n            mode (bool): Indicate if the model should be in training or\n                evaluation mode. Default: ``True``.\n        Returns:\n            The module instance after updating the training mode.\n        \"\"\"\n\n    def eval(self) -> Module:\n        \"\"\"Set the model to evaluation mode.\n\n        See :func:`train`.\n        \"\"\"\n\n    def set_dtype(\n        self, dtype: mx.Dtype, predicate: Optional[Callable[[mx.Dtype], bool]] = ...\n    ):  # -> None:\n        \"\"\"Set the dtype of the module's parameters.\n\n        Args:\n            dtype (Dtype): The new dtype.\n            predicate (typing.Callable, optional): A predicate to select\n              parameters to cast. By default, only parameters of type\n              :attr:`floating` will be updated to avoid casting integer\n              parameters to the new dtype.\n        \"\"\"\n"
  },
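  {
    "path": "examples/module_freeze_sketch.py",
    "content": "\"\"\"Hypothetical usage sketch, not part of MLX or these stubs: freezing\nparameters and casting dtypes on a ``Module`` as documented above. The file\npath is an illustrative assumption only.\"\"\"\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nmodel = nn.Linear(8, 8)\n\n# Frozen parameters drop out of trainable_parameters(), so value_and_grad\n# will not differentiate with respect to them.\nmodel.freeze(keys='bias')\nprint(model.trainable_parameters().keys())  # no 'bias' entry\n\n# set_dtype casts only floating-point parameters by default.\nmodel.set_dtype(mx.float16)\n"
  },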
  {
    "path": ".mlx_typings/mlx/nn/layers/containers.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import Callable\n\nimport mlx.core as mx\nfrom base import Module\n\nclass Sequential(Module):\n    \"\"\"A layer that calls the passed callables in order.\n\n    We can pass either modules or plain callables to the Sequential module. If\n    our functions have learnable parameters they should be implemented as\n    ``nn.Module`` instances.\n\n    Args:\n        modules (tuple of Callables): The modules to call in order\n    \"\"\"\n    def __init__(self, *modules: Module | Callable[[mx.array], mx.array]) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n"
  },
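  {
    "path": "examples/sequential_sketch.py",
    "content": "\"\"\"Hypothetical usage sketch, not part of MLX or these stubs: mixing Modules\nand plain callables in ``Sequential`` as documented above. The file path is an\nillustrative assumption only.\"\"\"\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\n# A parameter-free function like nn.relu can be passed directly; anything\n# with learnable parameters should be a Module.\nmlp = nn.Sequential(\n    nn.Linear(4, 16),\n    nn.relu,\n    nn.Linear(16, 1),\n)\n\ny = mlp(mx.random.normal((2, 4)))\nmx.eval(y)  # y has shape (2, 1)\n"
  },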
  {
    "path": ".mlx_typings/mlx/nn/layers/convolution.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import Union\n\nimport mlx.core as mx\nfrom base import Module\n\nclass Conv1d(Module):\n    \"\"\"Applies a 1-dimensional convolution over the multi-channel input sequence.\n\n    The channels are expected to be last i.e. the input shape should be ``NLC`` where:\n\n    * ``N`` is the batch dimension\n    * ``L`` is the sequence length\n    * ``C`` is the number of input channels\n\n    Args:\n        in_channels (int): The number of input channels\n        out_channels (int): The number of output channels\n        kernel_size (int): The size of the convolution filters\n        stride (int, optional): The stride when applying the filter.\n            Default: ``1``.\n        padding (int, optional): How many positions to 0-pad the input with.\n            Default: ``0``.\n        dilation (int, optional): The dilation of the convolution.\n        groups (int, optional): The number of groups for the convolution.\n            Default: ``1``.\n        bias (bool, optional): If ``True`` add a learnable bias to the output.\n            Default: ``True``\n    \"\"\"\n\n    weight: mx.array\n    bias: mx.array | None\n    groups: int\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: int,\n        stride: int = ...,\n        padding: int = ...,\n        dilation: int = ...,\n        groups: int = ...,\n        bias: bool = ...,\n    ) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n\nclass Conv2d(Module):\n    \"\"\"Applies a 2-dimensional convolution over the multi-channel input image.\n\n    The channels are expected to be last i.e. the input shape should be ``NHWC`` where:\n\n    * ``N`` is the batch dimension\n    * ``H`` is the input image height\n    * ``W`` is the input image width\n    * ``C`` is the number of input channels\n\n    Args:\n        in_channels (int): The number of input channels.\n        out_channels (int): The number of output channels.\n        kernel_size (int or tuple): The size of the convolution filters.\n        stride (int or tuple, optional): The size of the stride when\n            applying the filter. Default: ``1``.\n        padding (int or tuple, optional): How many positions to 0-pad\n            the input with. Default: ``0``.\n        dilation (int or tuple, optional): The dilation of the convolution.\n        groups (int, optional): The number of groups for the convolution.\n            Default: ``1``.\n        bias (bool, optional): If ``True`` add a learnable bias to the\n            output. Default: ``True``\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: Union[int, tuple],\n        stride: Union[int, tuple] = ...,\n        padding: Union[int, tuple] = ...,\n        dilation: Union[int, tuple] = ...,\n        groups: int = ...,\n        bias: bool = ...,\n    ) -> None: ...\n    def __call__(self, x) -> mx.array: ...\n\nclass Conv3d(Module):\n    \"\"\"Applies a 3-dimensional convolution over the multi-channel input image.\n\n    The channels are expected to be last i.e. 
the input shape should be ``NDHWC`` where:\n\n    * ``N`` is the batch dimension\n    * ``D`` is the input image depth\n    * ``H`` is the input image height\n    * ``W`` is the input image width\n    * ``C`` is the number of input channels\n\n    Args:\n        in_channels (int): The number of input channels.\n        out_channels (int): The number of output channels.\n        kernel_size (int or tuple): The size of the convolution filters.\n        stride (int or tuple, optional): The size of the stride when\n            applying the filter. Default: ``1``.\n        padding (int or tuple, optional): How many positions to 0-pad\n            the input with. Default: ``0``.\n        dilation (int or tuple, optional): The dilation of the convolution.\n        bias (bool, optional): If ``True`` add a learnable bias to the\n            output. Default: ``True``\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: Union[int, tuple],\n        stride: Union[int, tuple] = ...,\n        padding: Union[int, tuple] = ...,\n        dilation: Union[int, tuple] = ...,\n        bias: bool = ...,\n    ) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n"
  },
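To make the channels-last convention concrete, here is a minimal sketch of the convolution layers stubbed above (assuming the standard `mlx.nn` re-exports):

```python
import mlx.core as mx
import mlx.nn as nn

# Conv1d expects NLC input: batch, length, channels.
x1 = mx.random.normal((4, 100, 8))
y1 = nn.Conv1d(in_channels=8, out_channels=16, kernel_size=3, padding=1)(x1)
# y1 keeps length 100 because padding=1 offsets the kernel_size=3 shrinkage.

# Conv2d expects NHWC input: batch, height, width, channels.
x2 = mx.random.normal((4, 32, 32, 3))
y2 = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1)(x2)
```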
  {
    "path": ".mlx_typings/mlx/nn/layers/convolution_transpose.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import Union\n\nimport mlx.core as mx\nfrom base import Module\n\nclass ConvTranspose1d(Module):\n    \"\"\"Applies a 1-dimensional transposed convolution over the multi-channel input sequence.\n\n    The channels are expected to be last i.e. the input shape should be ``NLC`` where:\n\n    * ``N`` is the batch dimension\n    * ``L`` is the sequence length\n    * ``C`` is the number of input channels\n\n    Args:\n        in_channels (int): The number of input channels\n        out_channels (int): The number of output channels\n        kernel_size (int): The size of the convolution filters\n        stride (int, optional): The stride when applying the filter.\n            Default: ``1``.\n        padding (int, optional): How many positions to 0-pad the input with.\n            Default: ``0``.\n        dilation (int, optional): The dilation of the convolution.\n        output_padding(int, optional): Additional size added to one side of the\n            output shape. Default: ``0``.\n        bias (bool, optional): If ``True`` add a learnable bias to the output.\n            Default: ``True``\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: int,\n        stride: int = ...,\n        padding: int = ...,\n        dilation: int = ...,\n        output_padding: int = ...,\n        bias: bool = ...,\n    ) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n\nclass ConvTranspose2d(Module):\n    \"\"\"Applies a 2-dimensional transposed convolution over the multi-channel input image.\n\n    The channels are expected to be last i.e. the input shape should be ``NHWC`` where:\n\n    * ``N`` is the batch dimension\n    * ``H`` is the input image height\n    * ``W`` is the input image width\n    * ``C`` is the number of input channels\n\n    Args:\n        in_channels (int): The number of input channels.\n        out_channels (int): The number of output channels.\n        kernel_size (int or tuple): The size of the convolution filters.\n        stride (int or tuple, optional): The size of the stride when\n            applying the filter. Default: ``1``.\n        padding (int or tuple, optional): How many positions to 0-pad\n            the input with. Default: ``0``.\n        dilation (int or tuple, optional): The dilation of the convolution.\n        output_padding(int or tuple, optional): Additional size added to one\n            side of the output shape. Default: ``0``.\n        bias (bool, optional): If ``True`` add a learnable bias to the\n            output. Default: ``True``\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: Union[int, tuple],\n        stride: Union[int, tuple] = ...,\n        padding: Union[int, tuple] = ...,\n        dilation: Union[int, tuple] = ...,\n        output_padding: Union[int, tuple] = ...,\n        bias: bool = ...,\n    ) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n\nclass ConvTranspose3d(Module):\n    \"\"\"Applies a 3-dimensional transposed convolution over the multi-channel input image.\n\n    The channels are expected to be last i.e. 
the input shape should be ``NDHWC`` where:\n\n    * ``N`` is the batch dimension\n    * ``D`` is the input image depth\n    * ``H`` is the input image height\n    * ``W`` is the input image width\n    * ``C`` is the number of input channels\n\n    Args:\n        in_channels (int): The number of input channels.\n        out_channels (int): The number of output channels.\n        kernel_size (int or tuple): The size of the convolution filters.\n        stride (int or tuple, optional): The size of the stride when\n            applying the filter. Default: ``1``.\n        padding (int or tuple, optional): How many positions to 0-pad\n            the input with. Default: ``0``.\n        dilation (int or tuple, optional): The dilation of the convolution.\n        output_padding (int or tuple, optional): Additional size added to one\n            side of the output shape. Default: ``0``.\n        bias (bool, optional): If ``True`` add a learnable bias to the\n            output. Default: ``True``\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: Union[int, tuple],\n        stride: Union[int, tuple] = ...,\n        padding: Union[int, tuple] = ...,\n        dilation: Union[int, tuple] = ...,\n        output_padding: Union[int, tuple] = ...,\n        bias: bool = ...,\n    ) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n"
  },
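Transposed convolutions roughly invert the spatial downsampling of a strided convolution; by the usual size arithmetic the output length is `(L - 1) * stride - 2 * padding + kernel_size + output_padding`. A minimal sketch (assuming the standard `mlx.nn` re-exports):

```python
import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal((2, 50, 16))  # NLC
up = nn.ConvTranspose1d(
    in_channels=16, out_channels=8, kernel_size=4, stride=2, padding=1
)
y = up(x)  # (50 - 1) * 2 - 2 * 1 + 4 = 100, so y has shape (2, 100, 8)
```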
  {
    "path": ".mlx_typings/mlx/nn/layers/distributed.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom functools import lru_cache\nfrom typing import Callable, Optional, Union\n\nimport mlx.core as mx\nfrom base import Module\nfrom mlx.nn.layers.linear import Linear\n\n@lru_cache\ndef sum_gradients(\n    group: mx.distributed.Group,\n) -> Callable[..., mx.array]:  # -> Callable[..., Any] | Callable[..., array]:\n    ...\ndef shard_inplace(\n    module: Module,\n    sharding: str,\n    *,\n    segments: Union[int, list[int]] = ...,\n    group: Optional[mx.distributed.Group] = ...,\n) -> None:\n    \"\"\"Shard a module in-place by updating its parameter dictionary with the\n    sharded parameter dictionary.\n\n    The ``sharding`` argument can be any callable that given the path and the\n    weight returns the sharding axis and optionally also the segments that\n    comprise the unsharded weight. For instance if the weight is a fused QKV\n    matrix the segments should be 3.\n\n    .. note::\n        The module doesn't change so in order for distributed communication to\n        happen the module needs to natively support it and for it to be enabled.\n\n    Args:\n        module (Module): The parameters of this module will be sharded\n            in-place.\n        sharding (str or callable): One of \"all-to-sharded\" and\n            \"sharded-to-all\" or a callable that returns the sharding axis and\n            segments.\n        segments (int or list): The segments to use if ``sharding`` is a\n            string. Default: ``1``.\n        group (mlx.core.distributed.Group): The distributed group to shard\n            across. If not set, the global group will be used. Default: ``None``.\n    \"\"\"\n\ndef shard_linear(\n    module: Module,\n    sharding: str,\n    *,\n    segments: Union[int, list[int]] = ...,\n    group: Optional[mx.distributed.Group] = ...,\n) -> Linear:\n    \"\"\"Create a new linear layer that has its parameters sharded and also\n    performs distributed communication either in the forward or backward\n    pass.\n\n    .. note::\n        Contrary to ``shard_inplace``, the original layer is not changed but a\n        new layer is returned.\n\n    Args:\n        module (Module): The linear layer to be sharded.\n        sharding (str): One of \"all-to-sharded\" and\n            \"sharded-to-all\" that defines the type of sharding to perform.\n        segments (int or list): The segments to use. Default: ``1``.\n        group (mlx.core.distributed.Group): The distributed group to shard\n            across. If not set, the global group will be used. Default: ``None``.\n    \"\"\"\n\nclass AllToShardedLinear(Module):\n    \"\"\"Each member of the group applies part of the affine transformation such\n    that the result is sharded across the group.\n\n    The gradients are automatically aggregated from each member of the group.\n\n    Args:\n        input_dims (int): The dimensionality of the input features\n        output_dims (int): The dimensionality of the output features\n        bias (bool, optional): If set to ``False`` the the layer will not use a\n            bias. Default is ``True``.\n        group (mx.distributed.Group, optional): The sharding will happen across\n            this group. If not set then the global group is used. 
Default is\n            ``None``.\n    \"\"\"\n    def __init__(\n        self,\n        input_dims: int,\n        output_dims: int,\n        bias: bool = ...,\n        group: Optional[mx.distributed.Group] = ...,\n    ) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n    @classmethod\n    def from_linear(\n        cls,\n        linear_layer: Module,\n        *,\n        segments: Union[int, list[int]] = ...,\n        group: Optional[mx.distributed.Group] = ...,\n    ) -> AllToShardedLinear: ...\n\nclass ShardedToAllLinear(Module):\n    \"\"\"Each member of the group applies part of the affine transformation and\n    then aggregates the results.\n\n    All nodes will have the same exact result after this layer.\n\n    :class:`ShardedToAllLinear` provides a classmethod :meth:`from_linear` to\n    convert linear layers to sharded :obj:`ShardedToAllLinear` layers.\n\n    Args:\n        input_dims (int): The dimensionality of the input features\n        output_dims (int): The dimensionality of the output features\n        bias (bool, optional): If set to ``False`` then the layer will not use a\n            bias. Default is ``True``.\n        group (mx.distributed.Group, optional): The sharding will happen across\n            this group. If not set then the global group is used. Default is\n            ``None``.\n    \"\"\"\n    def __init__(\n        self,\n        input_dims: int,\n        output_dims: int,\n        bias: bool = ...,\n        group: Optional[mx.distributed.Group] = ...,\n    ) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n    @classmethod\n    def from_linear(\n        cls,\n        linear_layer: Module,\n        *,\n        segments: Union[int, list[int]] = ...,\n        group: Optional[mx.distributed.Group] = ...,\n    ) -> ShardedToAllLinear: ...\n\nclass QuantizedAllToShardedLinear(Module):\n    \"\"\"Each member of the group applies part of the affine transformation with\n    a quantized matrix such that the result is sharded across the group.\n\n    It is the quantized equivalent of :class:`AllToShardedLinear`.\n    Similar to :class:`QuantizedLinear` its parameters are frozen and\n    will not be included in any gradient computation.\n\n    Args:\n        input_dims (int): The dimensionality of the input features.\n        output_dims (int): The dimensionality of the output features.\n        bias (bool, optional): If set to ``False`` then the layer will not use\n            a bias. Default: ``True``.\n        group_size (int, optional): The group size to use for the quantized\n            weight. See :func:`~mlx.core.quantize`. Default: ``64``.\n        bits (int, optional): The bit width to use for the quantized weight.\n            See :func:`~mlx.core.quantize`. Default: ``4``.\n        group (mx.distributed.Group, optional): The sharding will happen across\n            this group. If not set then the global group is used.
Default is\n            ``None``.\n    \"\"\"\n    def __init__(\n        self,\n        input_dims: int,\n        output_dims: int,\n        bias: bool = ...,\n        group_size: int = ...,\n        bits: int = ...,\n        group: Optional[mx.distributed.Group] = ...,\n    ) -> None: ...\n    def unfreeze(self, *args, **kwargs) -> None:\n        \"\"\"Wrap unfreeze so that we unfreeze any layers we might contain but\n        our parameters will remain frozen.\"\"\"\n\n    def __call__(self, x: mx.array) -> mx.array: ...\n    @classmethod\n    def from_quantized_linear(\n        cls,\n        quantized_linear_layer: Module,\n        *,\n        segments: Union[int, list[int]] = ...,\n        group: Optional[mx.distributed.Group] = ...,\n    ) -> QuantizedAllToShardedLinear: ...\n\nclass QuantizedShardedToAllLinear(Module):\n    \"\"\"Each member of the group applies part of the affine transformation using\n    the quantized matrix and then aggregates the results.\n\n    All nodes will have the same exact result after this layer.\n\n    It is the quantized equivalent of :class:`ShardedToAllLinear`.\n    Similar to :class:`QuantizedLinear` its parameters are frozen and\n    will not be included in any gradient computation.\n\n    Args:\n        input_dims (int): The dimensionality of the input features.\n        output_dims (int): The dimensionality of the output features.\n        bias (bool, optional): If set to ``False`` then the layer will not use\n            a bias. Default: ``True``.\n        group_size (int, optional): The group size to use for the quantized\n            weight. See :func:`~mlx.core.quantize`. Default: ``64``.\n        bits (int, optional): The bit width to use for the quantized weight.\n            See :func:`~mlx.core.quantize`. Default: ``4``.\n        group (mx.distributed.Group, optional): The sharding will happen across\n            this group. If not set then the global group is used. Default is\n            ``None``.\n    \"\"\"\n    def __init__(\n        self,\n        input_dims: int,\n        output_dims: int,\n        bias: bool = ...,\n        group_size: int = ...,\n        bits: int = ...,\n        group: Optional[mx.distributed.Group] = ...,\n    ) -> None: ...\n    def unfreeze(self, *args, **kwargs) -> None:\n        \"\"\"Wrap unfreeze so that we unfreeze any layers we might contain but\n        our parameters will remain frozen.\"\"\"\n\n    def __call__(self, x: mx.array) -> mx.array: ...\n    @classmethod\n    def from_quantized_linear(\n        cls,\n        quantized_linear_layer: Module,\n        *,\n        segments: Union[int, list[int]] = ...,\n        group: Optional[mx.distributed.Group] = ...,\n    ) -> QuantizedShardedToAllLinear: ...\n"
  },
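A hypothetical sketch of the sharding helpers stubbed above; the exact entry points depend on the installed MLX version, so treat the import path and arguments as assumptions:

```python
import mlx.core as mx
import mlx.nn as nn
from mlx.nn.layers.distributed import shard_linear  # assumed import path

group = mx.distributed.init()          # the global distributed group
layer = nn.Linear(1024, 4096)

# Returns a new layer whose weight is sharded across the group; the
# original layer is unchanged (contrast with shard_inplace).
sharded = shard_linear(layer, "all-to-sharded", group=group)
```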
  {
    "path": ".mlx_typings/mlx/nn/layers/dropout.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom base import Module\n\nclass Dropout(Module):\n    r\"\"\"Randomly zero a portion of the elements during training.\n\n    The remaining elements are multiplied with :math:`\\frac{1}{1-p}` where\n    :math:`p` is the probability of zeroing an element. This is done so the\n    expected value of a given element will remain the same.\n\n    Args:\n        p (float): The probability to zero an element\n    \"\"\"\n    def __init__(self, p: float = ...) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n\nclass Dropout2d(Module):\n    r\"\"\"Apply 2D channel-wise dropout during training.\n\n    Randomly zero out entire channels independently with probability :math:`p`.\n    This layer expects the channels to be last, i.e. the input shape should be\n    ``NWHC`` or ``WHC`` where:``N`` is the batch dimension,``H`` is the input\n    image height,``W`` is the input image width, and``C`` is the number of\n    input channels\n\n    The remaining channels are scaled by :math:`\\frac{1}{1-p}` to\n    maintain the expected value of each element. Unlike traditional dropout,\n    which zeros individual entries, this layer zeros entire channels. This is\n    beneficial for early convolution layers where adjacent pixels are\n    correlated. In such case, traditional dropout may not effectively\n    regularize activations. For more details, see [1].\n\n    [1]: Thompson, J., Goroshin, R., Jain, A., LeCun, Y. and Bregler C., 2015.\n    Efficient Object Localization Using Convolutional Networks. CVPR 2015.\n\n    Args:\n        p (float): Probability of zeroing a channel during training.\n    \"\"\"\n    def __init__(self, p: float = ...) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n\nclass Dropout3d(Module):\n    r\"\"\"Apply 3D channel-wise dropout during training.\n\n    Randomly zero out entire channels independently with probability :math:`p`.\n    This layer expects the channels to be last, i.e., the input shape should be\n    `NDHWC` or `DHWC` where: `N` is the batch dimension, `D` is the depth,\n    `H` is the input image height, `W` is the input image width, and `C` is\n    the number of input channels.\n\n    The remaining channels are scaled by :math:`\\frac{1}{1-p}` to\n    maintain the expected value of each element. Unlike traditional dropout,\n    which zeros individual entries, this layer zeros entire channels. This is\n    often beneficial for convolutional layers processing 3D data, like in\n    medical imaging or video processing.\n\n    Args:\n        p (float): Probability of zeroing a channel during training.\n    \"\"\"\n    def __init__(self, p: float = ...) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n"
  },
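The train/eval distinction matters here: dropout only zeroes elements while the module is in training mode. A minimal sketch:

```python
import mlx.core as mx
import mlx.nn as nn

drop = nn.Dropout(p=0.5)
x = mx.ones((4, 8))

drop.train()          # training mode: zero elements, scale survivors by 1/(1-p)
y_train = drop(x)

drop.eval()           # evaluation mode: dropout is the identity
y_eval = drop(x)      # equal to x
```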
  {
    "path": ".mlx_typings/mlx/nn/layers/embedding.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom base import Module\n\nfrom .quantized import QuantizedEmbedding\n\nclass Embedding(Module):\n    \"\"\"Implements a simple lookup table that maps each input integer to a\n    high-dimensional vector.\n\n    Typically used to embed discrete tokens for processing by neural networks.\n\n    Args:\n        num_embeddings (int): How many possible discrete tokens can we embed.\n           Usually called the vocabulary size.\n        dims (int): The dimensionality of the embeddings.\n    \"\"\"\n    def __init__(self, num_embeddings: int, dims: int) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n    def as_linear(self, x: mx.array) -> mx.array:\n        \"\"\"\n        Call the embedding layer as a linear layer.\n\n        Use this for example when input embedding and output projection\n        weights are tied.\n        \"\"\"\n\n    def to_quantized(\n        self, group_size: int = ..., bits: int = ..., mode: str = ...\n    ) -> QuantizedEmbedding:\n        \"\"\"Return a :obj:`QuantizedEmbedding` layer that approximates this embedding layer.\"\"\"\n"
  },
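The ``as_linear`` method is what enables weight tying between the input embedding and the output projection; a minimal sketch:

```python
import mlx.core as mx
import mlx.nn as nn

emb = nn.Embedding(num_embeddings=1000, dims=64)
tokens = mx.array([[1, 2, 3]])

h = emb(tokens)             # (1, 3, 64): lookup into the embedding table
logits = emb.as_linear(h)   # (1, 3, 1000): same weight reused as a projection
```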
  {
    "path": ".mlx_typings/mlx/nn/layers/linear.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import Any\n\nimport mlx.core as mx\nfrom base import Module\n\nfrom .quantized import QuantizedLinear\n\nclass Identity(Module):\n    r\"\"\"A placeholder identity operator that is argument-insensitive.\n\n    Args:\n        args: any argument (unused)\n        kwargs: any keyword argument (unused)\n    \"\"\"\n    def __init__(self, *args: Any, **kwargs: Any) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n\nclass Linear(Module):\n    r\"\"\"Applies an affine transformation to the input.\n\n    Concretely:\n\n    .. math::\n\n        y = x W^\\top + b\n\n    where:\n    where :math:`W` has shape ``[output_dims, input_dims]`` and :math:`b` has shape ``[output_dims]``.\n\n    The values are initialized from the uniform distribution :math:`\\mathcal{U}(-{k}, {k})`,\n    where :math:`k = \\frac{1}{\\sqrt{D_i}}` and :math:`D_i` is equal to ``input_dims``.\n\n    Args:\n        input_dims (int): The dimensionality of the input features\n        output_dims (int): The dimensionality of the output features\n        bias (bool, optional): If set to ``False`` then the layer will\n          not use a bias. Default is ``True``.\n    \"\"\"\n\n    weight: mx.array\n    bias: mx.array | None\n\n    def __init__(self, input_dims: int, output_dims: int, bias: bool = ...) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n    def to_quantized(\n        self, group_size: int = ..., bits: int = ..., mode: str = ...\n    ) -> QuantizedLinear:\n        \"\"\"Return a :obj:`QuantizedLinear` layer that approximates this layer.\"\"\"\n\nclass Bilinear(Module):\n    r\"\"\"Applies a bilinear transformation to the inputs.\n\n    Concretely:\n\n    .. math::\n\n        y_i = x_1^\\top W_i x_2 + b_i\n\n    where:\n    :math:`W` has shape ``[output_dims, input1_dims, input2_dims]``, :math:`b` has shape ``[output_dims ]``,\n    and :math:`i` indexes the output dimension.\n\n    The values are initialized from the uniform distribution :math:`\\mathcal{U}(-{k}, {k})`,\n    where :math:`k = \\frac{1}{\\sqrt{D_1}}` and :math:`D_1` is ``input1_dims``.\n\n    Args:\n        input1_dims (int): The dimensionality of the input1 features\n        input2_dims (int): The dimensionality of the input2 features\n        output_dims (int): The dimensionality of the output features\n        bias (bool, optional): If set to ``False`` then the layer will\n          not use a bias. Default is ``True``.\n    \"\"\"\n    def __init__(\n        self, input1_dims: int, input2_dims: int, output_dims: int, bias: bool = ...\n    ) -> None: ...\n    def __call__(self, x1: mx.array, x2: mx.array) -> mx.array: ...\n"
  },
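A minimal sketch of the two affine layers stubbed above:

```python
import mlx.core as mx
import mlx.nn as nn

# Linear: y = x W^T + b, with W of shape [output_dims, input_dims].
linear = nn.Linear(input_dims=8, output_dims=3)
x1 = mx.random.normal((5, 8))
y = linear(x1)  # shape (5, 3)

# Bilinear combines two inputs: y_i = x1^T W_i x2 + b_i.
bilinear = nn.Bilinear(input1_dims=8, input2_dims=6, output_dims=3)
x2 = mx.random.normal((5, 6))
z = bilinear(x1, x2)  # shape (5, 3)
```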
  {
    "path": ".mlx_typings/mlx/nn/layers/normalization.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.core as mx\nfrom base import Module\n\nclass InstanceNorm(Module):\n    r\"\"\"Applies instance normalization [1] on the inputs.\n\n    Computes\n\n    .. math::\n\n        y = \\frac{x - \\mathrm{E}[x]}{ \\sqrt{\\mathrm{Var}[x] + \\epsilon}} * \\gamma + \\beta,\n\n    where :math:`\\gamma` and :math:`\\beta` are learned per feature dimension\n    parameters initialized at 1 and 0 respectively. Both are of size :attr:`dims`,\n    if :attr:`affine` is ``True``.\n\n    Args:\n        dims (int): The number of features of the input.\n        eps (float): A value added to the denominator for numerical stability. Default: ``1e-5``.\n        affine (bool): Default: ``False``.\n\n    Shape:\n      - Input: :math:`(..., C)` where :math:`C` is equal to :attr:`dims`.\n      - Output: Same shape as the input.\n\n    Examples:\n        >>> import mlx.core as mx\n        >>> import mlx.nn as nn\n        >>> x = mx.random.normal((8, 4, 4, 16))\n        >>> inorm = nn.InstanceNorm(dims=16)\n        >>> output = inorm(x)\n\n    References:\n        [1]: https://arxiv.org/abs/1607.08022\n    \"\"\"\n    def __init__(self, dims: int, eps: float = ..., affine: bool = ...) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n\nclass LayerNorm(Module):\n    r\"\"\"Applies layer normalization [1] on the inputs.\n\n    Computes\n\n    .. math::\n\n        y = \\frac{x - E[x]}{\\sqrt{Var[x]} + \\epsilon} \\gamma + \\beta,\n\n    where :math:`\\gamma` and :math:`\\beta` are learned per feature dimension\n    parameters initialized at 1 and 0 respectively.\n\n    [1]: https://arxiv.org/abs/1607.06450\n\n    Args:\n        dims (int): The feature dimension of the input to normalize over\n        eps (float): A small additive constant for numerical stability\n        affine (bool): If True learn an affine transform to apply after the\n            normalization\n        bias (bool): If True include a translation to the affine\n            transformation. If set to False the transformation is not really affine\n            just scaling.\n    \"\"\"\n    def __init__(\n        self, dims: int, eps: float = ..., affine: bool = ..., bias: bool = ...\n    ) -> None: ...\n    def __call__(self, x) -> mx.array: ...\n\nclass RMSNorm(Module):\n    r\"\"\"Applies Root Mean Square normalization [1] to the inputs.\n\n    Computes\n\n    ..  math::\n\n        y = \\frac{x}{\\sqrt{E[x^2] + \\epsilon}} \\gamma\n\n    where :math:`\\gamma` is a learned per feature dimension parameter initialized at\n    1.\n\n    Note the accumulation for the mean is done in 32-bit precision.\n\n    [1]: https://arxiv.org/abs/1910.07467\n\n    Args:\n        dims (int): The feature dimension of the input to normalize over\n        eps (float): A small additive constant for numerical stability\n    \"\"\"\n\n    weight: mx.array\n\n    def __init__(self, dims: int, eps: float = ...) -> None: ...\n    def __call__(self, x) -> mx.array: ...\n\nclass GroupNorm(Module):\n    r\"\"\"Applies Group Normalization [1] to the inputs.\n\n    Computes the same normalization as layer norm, namely\n\n    .. math::\n\n        y = \\frac{x - E[x]}{\\sqrt{Var[x]} + \\epsilon} \\gamma + \\beta,\n\n    where :math:`\\gamma` and :math:`\\beta` are learned per feature dimension\n    parameters initialized at 1 and 0 respectively. However, the mean and\n    variance are computed over the spatial dimensions and each group of\n    features. 
In particular, the input is split into ``num_groups`` groups across the\n    feature dimension.\n\n    The feature dimension is assumed to be the last dimension and the dimensions\n    that precede it (except the first) are considered the spatial dimensions.\n\n    [1]: https://arxiv.org/abs/1803.08494\n\n    Args:\n        num_groups (int): Number of groups to separate the features into\n        dims (int): The feature dimensions of the input to normalize over\n        eps (float): A small additive constant for numerical stability\n        affine (bool): If ``True`` learn an affine transform to apply after the\n            normalization.\n        pytorch_compatible (bool): If ``True`` perform the group normalization in\n            the same order/grouping as PyTorch.\n    \"\"\"\n    def __init__(\n        self,\n        num_groups: int,\n        dims: int,\n        eps: float = ...,\n        affine: bool = ...,\n        pytorch_compatible: bool = ...,\n    ) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n\nclass BatchNorm(Module):\n    r\"\"\"Applies Batch Normalization over a 2D or 3D input.\n\n    Computes\n\n    .. math::\n\n        y = \\frac{x - E[x]}{\\sqrt{Var[x] + \\epsilon}} \\gamma + \\beta,\n\n    where :math:`\\gamma` and :math:`\\beta` are learned per feature dimension\n    parameters initialized at 1 and 0 respectively.\n\n    The input shape is specified as ``NC`` or ``NLC``, where ``N`` is the\n    batch, ``C`` is the number of features or channels, and ``L`` is the\n    sequence length. The output has the same shape as the input. For\n    four-dimensional arrays, the shape is ``NHWC``, where ``H`` and ``W`` are\n    the height and width respectively.\n\n    For more information on Batch Normalization, see the original paper `Batch\n    Normalization: Accelerating Deep Network Training by Reducing Internal\n    Covariate Shift <https://arxiv.org/abs/1502.03167>`_.\n\n    Args:\n        num_features (int): The feature dimension to normalize over.\n        eps (float, optional): A small additive constant for numerical\n            stability. Default: ``1e-5``.\n        momentum (float, optional): The momentum for updating the running\n            mean and variance. Default: ``0.1``.\n        affine (bool, optional): If ``True``, apply a learned affine\n            transformation after the normalization. Default: ``True``.\n        track_running_stats (bool, optional): If ``True``, track the\n            running mean and variance. Default: ``True``.\n\n    Examples:\n        >>> import mlx.core as mx\n        >>> import mlx.nn as nn\n        >>> x = mx.random.normal((5, 4))\n        >>> bn = nn.BatchNorm(num_features=4, affine=True)\n        >>> output = bn(x)\n    \"\"\"\n    def __init__(\n        self,\n        num_features: int,\n        eps: float = ...,\n        momentum: float = ...,\n        affine: bool = ...,\n        track_running_stats: bool = ...,\n    ) -> None: ...\n    def unfreeze(self, *args, **kwargs) -> None:\n        \"\"\"Wrap unfreeze to make sure that running_mean and var are always\n        frozen parameters.\"\"\"\n\n    def __call__(self, x: mx.array) -> mx.array:\n        \"\"\"\n        Forward pass of BatchNorm.\n\n        Args:\n            x (array): Input tensor.\n\n        Returns:\n            array: Normalized output tensor.\n        \"\"\"\n"
  },
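All of these layers normalize over the trailing feature dimension; a minimal sketch contrasting the three most common ones:

```python
import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal((2, 10, 32))  # (batch, sequence, features)

layer_norm = nn.LayerNorm(dims=32)                 # mean/variance per position
rms_norm = nn.RMSNorm(dims=32)                     # scale-only, no mean subtraction
group_norm = nn.GroupNorm(num_groups=4, dims=32)   # stats per group of 8 features

y = layer_norm(x)  # same shape as x
```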
  {
    "path": ".mlx_typings/mlx/nn/layers/pooling.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import Optional, Tuple, Union\n\nimport mlx.core as mx\nfrom base import Module\n\nclass _Pool(Module):\n    def __init__(\n        self, pooling_function, kernel_size, stride, padding, padding_value\n    ) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n\nclass _Pool1d(_Pool):\n    def __init__(\n        self,\n        pooling_function,\n        padding_value,\n        kernel_size: Union[int, Tuple[int]],\n        stride: Optional[Union[int, Tuple[int]]] = ...,\n        padding: Union[int, Tuple[int]] = ...,\n    ) -> None: ...\n\nclass _Pool2d(_Pool):\n    def __init__(\n        self,\n        pooling_function,\n        padding_value,\n        kernel_size: Union[int, Tuple[int, int]],\n        stride: Optional[Union[int, Tuple[int, int]]] = ...,\n        padding: Optional[Union[int, Tuple[int, int]]] = ...,\n    ) -> None: ...\n\nclass _Pool3d(_Pool):\n    def __init__(\n        self,\n        pooling_function,\n        padding_value,\n        kernel_size: Union[int, Tuple[int, int, int]],\n        stride: Optional[Union[int, Tuple[int, int, int]]] = ...,\n        padding: Optional[Union[int, Tuple[int, int, int]]] = ...,\n    ) -> None: ...\n\nclass MaxPool1d(_Pool1d):\n    r\"\"\"Applies 1-dimensional max pooling.\n\n    Spatially downsamples the input by taking the maximum of a sliding window\n    of size ``kernel_size`` and sliding stride ``stride``.\n\n    Args:\n        kernel_size (int or tuple(int)): The size of the pooling window kernel.\n        stride (int or tuple(int), optional): The stride of the pooling window.\n            Default: ``kernel_size``.\n        padding (int or tuple(int), optional): How much negative infinity\n            padding to apply to the input. The padding amount is applied to\n            both sides of the spatial axis. Default: ``0``.\n\n    Examples:\n        >>> import mlx.core as mx\n        >>> import layers as nn\n        >>> x = mx.random.normal(shape=(4, 16, 5))\n        >>> pool = nn.MaxPool1d(kernel_size=2, stride=2)\n        >>> pool(x)\n    \"\"\"\n    def __init__(\n        self,\n        kernel_size: Union[int, Tuple[int]],\n        stride: Optional[Union[int, Tuple[int]]] = ...,\n        padding: Union[int, Tuple[int]] = ...,\n    ) -> None: ...\n\nclass AvgPool1d(_Pool1d):\n    r\"\"\"Applies 1-dimensional average pooling.\n\n    Spatially downsamples the input by taking the average of a sliding window\n    of size ``kernel_size`` and sliding stride ``stride``.\n\n    Args:\n        kernel_size (int or tuple(int)): The size of the pooling window kernel.\n        stride (int or tuple(int), optional): The stride of the pooling window.\n            Default: ``kernel_size``.\n        padding (int or tuple(int), optional): How much zero padding to apply to\n            the input. The padding amount is applied to both sides of the spatial\n            axis. 
Default: ``0``.\n\n    Examples:\n        >>> import mlx.core as mx\n        >>> import mlx.nn as nn\n        >>> x = mx.random.normal(shape=(4, 16, 5))\n        >>> pool = nn.AvgPool1d(kernel_size=2, stride=2)\n        >>> pool(x)\n    \"\"\"\n    def __init__(\n        self,\n        kernel_size: Union[int, Tuple[int]],\n        stride: Optional[Union[int, Tuple[int]]] = ...,\n        padding: Union[int, Tuple[int]] = ...,\n    ) -> None: ...\n\nclass MaxPool2d(_Pool2d):\n    r\"\"\"Applies 2-dimensional max pooling.\n\n    Spatially downsamples the input by taking the maximum of a sliding window\n    of size ``kernel_size`` and sliding stride ``stride``.\n\n    The parameters ``kernel_size``, ``stride``, and ``padding`` can either be:\n\n    * a single ``int`` -- in which case the same value is used for both the\n      height and width axis.\n    * a ``tuple`` of two ``int`` s -- in which case, the first ``int`` is\n      used for the height axis, the second ``int`` for the width axis.\n\n    Args:\n        kernel_size (int or tuple(int, int)): The size of the pooling window.\n        stride (int or tuple(int, int), optional): The stride of the pooling\n            window. Default: ``kernel_size``.\n        padding (int or tuple(int, int), optional): How much negative infinity\n            padding to apply to the input. The padding is applied on both sides\n            of the height and width axis. Default: ``0``.\n\n    Examples:\n        >>> import mlx.core as mx\n        >>> import mlx.nn as nn\n        >>> x = mx.random.normal(shape=(8, 32, 32, 4))\n        >>> pool = nn.MaxPool2d(kernel_size=2, stride=2)\n        >>> pool(x)\n    \"\"\"\n    def __init__(\n        self,\n        kernel_size: Union[int, Tuple[int, int]],\n        stride: Optional[Union[int, Tuple[int, int]]] = ...,\n        padding: Optional[Union[int, Tuple[int, int]]] = ...,\n    ) -> None: ...\n\nclass AvgPool2d(_Pool2d):\n    r\"\"\"Applies 2-dimensional average pooling.\n\n    Spatially downsamples the input by taking the average of a sliding window\n    of size ``kernel_size`` and sliding stride ``stride``.\n\n    The parameters ``kernel_size``, ``stride``, and ``padding`` can either be:\n\n    * a single ``int`` -- in which case the same value is used for both the\n      height and width axis.\n    * a ``tuple`` of two ``int`` s -- in which case, the first ``int`` is\n      used for the height axis, the second ``int`` for the width axis.\n\n    Args:\n        kernel_size (int or tuple(int, int)): The size of the pooling window.\n        stride (int or tuple(int, int), optional): The stride of the pooling\n            window. Default: ``kernel_size``.\n        padding (int or tuple(int, int), optional): How much zero\n            padding to apply to the input. The padding is applied on both sides\n            of the height and width axis.
Default: ``0``.\n\n    Examples:\n        >>> import mlx.core as mx\n        >>> import mlx.nn as nn\n        >>> x = mx.random.normal(shape=(8, 32, 32, 4))\n        >>> pool = nn.AvgPool2d(kernel_size=2, stride=2)\n        >>> pool(x)\n    \"\"\"\n    def __init__(\n        self,\n        kernel_size: Union[int, Tuple[int, int]],\n        stride: Optional[Union[int, Tuple[int, int]]] = ...,\n        padding: Optional[Union[int, Tuple[int, int]]] = ...,\n    ) -> None: ...\n\nclass MaxPool3d(_Pool3d):\n    r\"\"\"Applies 3-dimensional max pooling.\n\n    Spatially downsamples the input by taking the maximum of a sliding window\n    of size ``kernel_size`` and sliding stride ``stride``.\n\n    The parameters ``kernel_size``, ``stride``, and ``padding`` can either be:\n\n    * a single ``int`` -- in which case the same value is used for the depth,\n      height, and width axis.\n    * a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used\n      for the depth axis, the second ``int`` for the height axis, and the third\n      ``int`` for the width axis.\n\n    Args:\n        kernel_size (int or tuple(int, int, int)): The size of the pooling window.\n        stride (int or tuple(int, int, int), optional): The stride of the pooling\n            window. Default: ``kernel_size``.\n        padding (int or tuple(int, int, int), optional): How much negative infinity\n            padding to apply to the input. The padding is applied on both sides\n            of the depth, height and width axis. Default: ``0``.\n\n    Examples:\n        >>> import mlx.core as mx\n        >>> import mlx.nn as nn\n        >>> x = mx.random.normal(shape=(8, 16, 32, 32, 4))\n        >>> pool = nn.MaxPool3d(kernel_size=2, stride=2)\n        >>> pool(x)\n    \"\"\"\n    def __init__(\n        self,\n        kernel_size: Union[int, Tuple[int, int, int]],\n        stride: Optional[Union[int, Tuple[int, int, int]]] = ...,\n        padding: Optional[Union[int, Tuple[int, int, int]]] = ...,\n    ) -> None: ...\n\nclass AvgPool3d(_Pool3d):\n    r\"\"\"Applies 3-dimensional average pooling.\n\n    Spatially downsamples the input by taking the average of a sliding window\n    of size ``kernel_size`` and sliding stride ``stride``.\n\n    The parameters ``kernel_size``, ``stride``, and ``padding`` can either be:\n\n    * a single ``int`` -- in which case the same value is used for the depth,\n      height, and width axis.\n    * a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used\n      for the depth axis, the second ``int`` for the height axis, and the third\n      ``int`` for the width axis.\n\n    Args:\n        kernel_size (int or tuple(int, int, int)): The size of the pooling window.\n        stride (int or tuple(int, int, int), optional): The stride of the pooling\n            window. Default: ``kernel_size``.\n        padding (int or tuple(int, int, int), optional): How much zero\n            padding to apply to the input. The padding is applied on both sides\n            of the depth, height and width axis.
Default: ``0``.\n\n    Examples:\n        >>> import mlx.core as mx\n        >>> import mlx.nn as nn\n        >>> x = mx.random.normal(shape=(8, 16, 32, 32, 4))\n        >>> pool = nn.AvgPool3d(kernel_size=2, stride=2)\n        >>> pool(x)\n    \"\"\"\n    def __init__(\n        self,\n        kernel_size: Union[int, Tuple[int, int, int]],\n        stride: Optional[Union[int, Tuple[int, int, int]]] = ...,\n        padding: Optional[Union[int, Tuple[int, int, int]]] = ...,\n    ) -> None: ...\n"
  },
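With padding applied to both sides, the spatial output size follows the usual pooling arithmetic, `floor((L + 2 * padding - kernel_size) / stride) + 1`. A minimal sketch:

```python
import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal((4, 16, 5))             # NLC
pool = nn.MaxPool1d(kernel_size=2, stride=2)
y = pool(x)                                  # (16 - 2) // 2 + 1 = 8 -> (4, 8, 5)
```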
  {
    "path": ".mlx_typings/mlx/nn/layers/positional_encoding.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import Optional\n\nimport mlx.core as mx\nfrom base import Module\n\nclass RoPE(Module):\n    \"\"\"Implements the rotary positional encoding.\n\n    The traditional implementation rotates consecutive pairs of elements in the\n    feature dimension while the default implementation rotates pairs with\n    stride half the feature dimensions for efficiency.\n\n    For more details see `RoFormer: Enhanced Transformer with Rotary Position\n    Embedding <https://arxiv.org/abs/2104.09864>`_.\n\n    Args:\n        dims (int): The feature dimensions to be rotated. If the input feature\n            is larger than dims then the rest is left unchanged.\n        traditional (bool, optional): If set to ``True`` choose the traditional\n            implementation which is slightly less efficient. Default: ``False``.\n        base (float, optional): The base used to compute angular frequency for\n            each dimension in the positional encodings. Default: ``10000``.\n        scale (float, optional): The scale used to scale the positions. Default: ``1.0``.\n    \"\"\"\n    def __init__(\n        self, dims: int, traditional: bool = ..., base: float = ..., scale: float = ...\n    ) -> None: ...\n    def __call__(self, x, offset: int = ...) -> mx.array: ...\n\nclass SinusoidalPositionalEncoding(Module):\n    r\"\"\"Implements sinusoidal positional encoding.\n\n    For more details see the paper `Attention Is All You Need\n    <https://arxiv.org/abs/1706.03762>`_.\n\n    Args:\n        dims (int): The dimensionality of the resulting positional embeddings.\n        min_freq (float, optional): The minimum frequency expected. Default:\n            ``0.0001``.\n        max_freq (float, optional): The maximum frequency expected. Default:\n            ``1``.\n        scale (float, optional): A multiplicative scale for the embeddings.\n            Default: ``sqrt(2/dims)``.\n        cos_first (bool, optional): If ``True`` embed using ``[cos(x); sin(x)]``\n            instead of the reverse. Default: ``False``.\n        full_turns (bool, optional): If ``True`` multiply the frequencies with\n            :math:`2\\pi`. Default: ``False``.\n    \"\"\"\n    def __init__(\n        self,\n        dims: int,\n        min_freq: float = ...,\n        max_freq: float = ...,\n        scale: Optional[float] = ...,\n        cos_first: bool = ...,\n        full_turns: bool = ...,\n    ) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n\nclass ALiBi(Module):\n    _alibi_mask_key = ...\n    _alibi_mask = ...\n    @classmethod\n    def create_alibi_matrix(\n        cls,\n        q_sequence_length: int,\n        k_sequence_length: int,\n        num_heads: int,\n        offset: int,\n        dtype=...,\n    ) -> mx.array | None: ...\n    @staticmethod\n    def create_alibi_slope(num_heads: int) -> mx.array: ...\n    def __call__(\n        self, attention_scores: mx.array, offset=..., mask=...\n    ) -> mx.array: ...\n"
  },
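The ``offset`` argument on ``RoPE`` is what makes it usable with a key-value cache: it lets a new position be rotated as if it followed the already-processed prefix. A minimal sketch:

```python
import mlx.core as mx
import mlx.nn as nn

rope = nn.RoPE(dims=64)

q = mx.random.normal((1, 8, 10, 64))   # (batch, heads, sequence, head_dim)
q_rotated = rope(q)                    # positions 0..9

q_next = mx.random.normal((1, 8, 1, 64))
q_next_rotated = rope(q_next, offset=10)  # continue at position 10
```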
  {
    "path": ".mlx_typings/mlx/nn/layers/quantized.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import Callable, Optional, Union\n\nimport mlx.core as mx\nfrom base import Module\n\ndef quantize(\n    model: Module,\n    group_size: int = ...,\n    bits: int = ...,\n    *,\n    mode: str = ...,\n    class_predicate: Optional[Callable[[str, Module], Union[bool, dict]]] = ...,\n):  # -> None:\n    \"\"\"Quantize the sub-modules of a module according to a predicate.\n\n    By default all layers that define a ``to_quantized(group_size, bits)``\n    method will be quantized. Both :obj:`Linear` and :obj:`Embedding` layers\n    will be quantized. Note also, the module is updated in-place.\n\n    Args:\n        model (Module): The model whose leaf modules may be quantized.\n        group_size (int): The quantization group size (see\n           :func:`mlx.core.quantize`). Default: ``64``.\n        bits (int): The number of bits per parameter (see\n           :func:`mlx.core.quantize`). Default: ``4``.\n        mode (str): The quantization method to use (see\n           :func:`mlx.core.quantize`). Default: ``\"affine\"``.\n        class_predicate (Optional[Callable]): A callable which receives the\n          :obj:`Module` path and :obj:`Module` itself and returns ``True`` or a\n          dict of params for `to_quantized` if it should be quantized and\n          ``False`` otherwise. If ``None``, then all layers that define a\n          ``to_quantized(group_size, bits)`` method are quantized.\n          Default: ``None``.\n    \"\"\"\n\nclass QuantizedEmbedding(Module):\n    \"\"\"The same as :obj:`Embedding` but with a  quantized weight matrix.\n\n    :obj:`QuantizedEmbedding` also provides a :meth:`from_embedding`\n    classmethod to convert embedding layers to :obj:`QuantizedEmbedding`\n    layers.\n\n    Args:\n        num_embeddings (int): How many possible discrete tokens can we embed.\n           Usually called the vocabulary size.\n        dims (int): The dimensionality of the embeddings.\n        group_size (int, optional): The group size to use for the quantized\n            weight. See :func:`~mlx.core.quantize`. Default: ``64``.\n        bits (int, optional): The bit width to use for the quantized weight.\n            See :func:`~mlx.core.quantize`. Default: ``4``.\n        mode (str): The quantization method to use (see\n           :func:`mlx.core.quantize`). Default: ``\"affine\"``.\n    \"\"\"\n    def __init__(\n        self,\n        num_embeddings: int,\n        dims: int,\n        group_size: int = ...,\n        bits: int = ...,\n        mode: str = ...,\n    ) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n    def as_linear(self, x: mx.array) -> mx.array:\n        \"\"\"\n        Call the quantized embedding layer as a quantized linear layer.\n\n        Use this for example when input embedding and output projection\n        weights are tied.\n        \"\"\"\n\n    @classmethod\n    def from_embedding(\n        cls,\n        embedding_layer: Module,\n        group_size: int = ...,\n        bits: int = ...,\n        mode: str = ...,\n    ) -> QuantizedEmbedding:\n        \"\"\"Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer.\"\"\"\n\nclass QuantizedLinear(Module):\n    \"\"\"Applies an affine transformation to the input using a quantized weight matrix.\n\n    It is the quantized equivalent of :class:`Linear`. 
For now its\n    parameters are frozen and will not be included in any gradient computation\n    but this will probably change in the future.\n\n    :obj:`QuantizedLinear` also provides a classmethod :meth:`from_linear` to\n    convert linear layers to :obj:`QuantizedLinear` layers.\n\n    Args:\n        input_dims (int): The dimensionality of the input features.\n        output_dims (int): The dimensionality of the output features.\n        bias (bool, optional): If set to ``False`` then the layer will not use\n            a bias. Default: ``True``.\n        group_size (int, optional): The group size to use for the quantized\n            weight. See :func:`~mlx.core.quantize`. Default: ``64``.\n        bits (int, optional): The bit width to use for the quantized weight.\n            See :func:`~mlx.core.quantize`. Default: ``4``.\n        mode (str): The quantization method to use (see\n           :func:`mlx.core.quantize`). Default: ``\"affine\"``.\n    \"\"\"\n    def __init__(\n        self,\n        input_dims: int,\n        output_dims: int,\n        bias: bool = ...,\n        group_size: int = ...,\n        bits: int = ...,\n        mode: str = ...,\n    ) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n    @classmethod\n    def from_linear(\n        cls,\n        linear_layer: Module,\n        group_size: int = ...,\n        bits: int = ...,\n        mode: str = ...,\n    ) -> QuantizedLinear:\n        \"\"\"Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer.\"\"\"\n"
  },
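A minimal sketch of in-place quantization (assuming the usual ``nn.quantize`` re-export); note that the chosen ``group_size`` must divide the layers' input dimensions:

```python
import mlx.nn as nn

model = nn.Sequential(
    nn.Linear(64, 128),
    nn.relu,
    nn.Linear(128, 64),
)

# Replaces every layer that defines to_quantized (Linear, Embedding)
# with its quantized counterpart, in-place.
nn.quantize(model, group_size=64, bits=4)
```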
  {
    "path": ".mlx_typings/mlx/nn/layers/recurrent.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import Callable, Optional\n\nimport mlx.core as mx\nfrom base import Module\n\nclass RNN(Module):\n    r\"\"\"An Elman recurrent layer.\n\n    The input is a sequence of shape ``NLD`` or ``LD`` where:\n\n    * ``N`` is the optional batch dimension\n    * ``L`` is the sequence length\n    * ``D`` is the input's feature dimension\n\n    Concretely, for each element along the sequence length axis, this\n    layer applies the function:\n\n    .. math::\n\n        h_{t + 1} = \\text{tanh} (W_{ih}x_t + W_{hh}h_t + b)\n\n    The hidden state :math:`h` has shape ``NH`` or ``H``, depending on\n    whether the input is batched or not. Returns the hidden state at each\n    time step, of shape ``NLH`` or ``LH``.\n\n    Args:\n        input_size (int): Dimension of the input, ``D``.\n        hidden_size (int): Dimension of the hidden state, ``H``.\n        bias (bool, optional): Whether to use a bias. Default: ``True``.\n        nonlinearity (callable, optional): Non-linearity to use. If ``None``,\n            then func:`tanh` is used. Default: ``None``.\n    \"\"\"\n    def __init__(\n        self,\n        input_size: int,\n        hidden_size: int,\n        bias: bool = ...,\n        nonlinearity: Optional[Callable] = ...,\n    ) -> None: ...\n    def __call__(self, x: mx.array, hidden=...) -> mx.array: ...\n\nclass GRU(Module):\n    r\"\"\"A gated recurrent unit (GRU) RNN layer.\n\n    The input has shape ``NLD`` or ``LD`` where:\n\n    * ``N`` is the optional batch dimension\n    * ``L`` is the sequence length\n    * ``D`` is the input's feature dimension\n\n    Concretely, for each element of the sequence, this layer computes:\n\n    .. math::\n\n        \\begin{aligned}\n        r_t &= \\sigma (W_{xr}x_t + W_{hr}h_t + b_{r}) \\\\\n        z_t &= \\sigma (W_{xz}x_t + W_{hz}h_t + b_{z}) \\\\\n        n_t &= \\text{tanh}(W_{xn}x_t + b_{n} + r_t \\odot (W_{hn}h_t + b_{hn})) \\\\\n        h_{t + 1} &= (1 - z_t) \\odot n_t + z_t \\odot h_t\n        \\end{aligned}\n\n    The hidden state :math:`h` has shape ``NH`` or ``H`` depending on\n    whether the input is batched or not. Returns the hidden state at each\n    time step of shape ``NLH`` or ``LH``.\n\n    Args:\n        input_size (int): Dimension of the input, ``D``.\n        hidden_size (int): Dimension of the hidden state, ``H``.\n        bias (bool): Whether to use biases or not. Default: ``True``.\n    \"\"\"\n    def __init__(self, input_size: int, hidden_size: int, bias: bool = ...) -> None: ...\n    def __call__(self, x: mx.array, hidden=...) -> mx.array: ...\n\nclass LSTM(Module):\n    r\"\"\"An LSTM recurrent layer.\n\n    The input has shape ``NLD`` or ``LD`` where:\n\n    * ``N`` is the optional batch dimension\n    * ``L`` is the sequence length\n    * ``D`` is the input's feature dimension\n\n    Concretely, for each element of the sequence, this layer computes:\n\n    .. 
math::\n        \\begin{aligned}\n        i_t &= \\sigma (W_{xi}x_t + W_{hi}h_t + b_{i}) \\\\\n        f_t &= \\sigma (W_{xf}x_t + W_{hf}h_t + b_{f}) \\\\\n        g_t &= \\text{tanh} (W_{xg}x_t + W_{hg}h_t + b_{g}) \\\\\n        o_t &= \\sigma (W_{xo}x_t + W_{ho}h_t + b_{o}) \\\\\n        c_{t + 1} &= f_t \\odot c_t + i_t \\odot g_t \\\\\n        h_{t + 1} &= o_t \\text{tanh}(c_{t + 1})\n        \\end{aligned}\n\n    The hidden state :math:`h` and cell state :math:`c` have shape ``NH``\n    or ``H``, depending on whether the input is batched or not.\n\n    The layer returns two arrays, the hidden state and the cell state at\n    each time step, both of shape ``NLH`` or ``LH``.\n\n    Args:\n        input_size (int): Dimension of the input, ``D``.\n        hidden_size (int): Dimension of the hidden state, ``H``.\n        bias (bool): Whether to use biases or not. Default: ``True``.\n    \"\"\"\n    def __init__(self, input_size: int, hidden_size: int, bias: bool = ...) -> None: ...\n    def __call__(\n        self, x: mx.array, hidden=..., cell=...\n    ) -> tuple[mx.array, mx.array]: ...\n"
  },
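The LSTM is the only layer here that returns a pair, one array per state; a minimal sketch:

```python
import mlx.core as mx
import mlx.nn as nn

lstm = nn.LSTM(input_size=32, hidden_size=64)
x = mx.random.normal((8, 20, 32))        # NLD: batch, sequence, features

hidden_states, cell_states = lstm(x)     # each of shape (8, 20, 64)
```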
  {
    "path": ".mlx_typings/mlx/nn/layers/transformer.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import Any, Callable, Optional\n\nimport mlx.core as mx\nfrom base import Module\n\nclass MultiHeadAttention(Module):\n    \"\"\"Implements the scaled dot product attention with multiple heads.\n\n    Given inputs for queries, keys and values the ``MultiHeadAttention``\n    produces new values by aggregating information from the input values\n    according to the similarities of the input queries and keys.\n\n    All inputs as well as the output are linearly projected without biases by\n    default.\n\n    ``MultiHeadAttention`` also takes an optional additive attention mask that\n    should be broadcastable with ``(batch, num_heads, # queries, # keys)``. The\n    mask should have ``-inf`` or very large negative numbers at the positions\n    that should *not* be attended to.\n\n    Args:\n        dims (int): The model dimensions. This is also the default\n            value for the queries, keys, values, and the output.\n        num_heads (int): The number of attention heads to use.\n        query_input_dims (int, optional): The input dimensions of the queries.\n            Default: ``dims``.\n        key_input_dims (int, optional): The input dimensions of the keys.\n            Default: ``dims``.\n        value_input_dims (int, optional): The input dimensions of the values.\n            Default: ``key_input_dims``.\n        value_dims (int, optional): The dimensions of the values after the\n            projection. Default: ``dims``.\n        value_output_dims (int, optional): The dimensions the new values will\n            be projected to. Default: ``dims``.\n        bias (bool, optional): Whether or not to use a bias in the projections.\n            Default: ``False``.\n    \"\"\"\n    def __init__(\n        self,\n        dims: int,\n        num_heads: int,\n        query_input_dims: Optional[int] = ...,\n        key_input_dims: Optional[int] = ...,\n        value_input_dims: Optional[int] = ...,\n        value_dims: Optional[int] = ...,\n        value_output_dims: Optional[int] = ...,\n        bias: bool = ...,\n    ) -> None: ...\n    def __call__(\n        self, queries: mx.array, keys: mx.array, values: mx.array, mask: mx.array = ...\n    ) -> mx.array: ...\n    @staticmethod\n    def create_additive_causal_mask(N: int, dtype: mx.Dtype = ...) 
-> mx.array: ...\n\nclass TransformerEncoderLayer(Module):\n    def __init__(\n        self,\n        dims: int,\n        num_heads: int,\n        mlp_dims: Optional[int] = ...,\n        dropout: float = ...,\n        activation: Callable[[Any], Any] = ...,\n        norm_first: bool = ...,\n    ) -> None: ...\n    def __call__(self, x: mx.array, mask: mx.array) -> mx.array: ...\n\nclass TransformerEncoder(Module):\n    def __init__(\n        self,\n        num_layers: int,\n        dims: int,\n        num_heads: int,\n        mlp_dims: Optional[int] = ...,\n        dropout: float = ...,\n        activation=...,\n        norm_first: bool = ...,\n        checkpoint: bool = ...,\n    ) -> None: ...\n    def __call__(self, x: mx.array, mask: mx.array) -> mx.array: ...\n\nclass TransformerDecoderLayer(Module):\n    def __init__(\n        self,\n        dims: int,\n        num_heads: int,\n        mlp_dims: Optional[int] = ...,\n        dropout: float = ...,\n        activation: Callable[[Any], Any] = ...,\n        norm_first: bool = ...,\n    ) -> None: ...\n    def __call__(self, x: mx.array, memory, x_mask, memory_mask) -> mx.array: ...\n\nclass TransformerDecoder(Module):\n    def __init__(\n        self,\n        num_layers: int,\n        dims: int,\n        num_heads: int,\n        mlp_dims: Optional[int] = ...,\n        dropout: float = ...,\n        activation=...,\n        norm_first: bool = ...,\n        checkpoint: bool = ...,\n    ) -> None: ...\n    def __call__(self, x: mx.array, memory, x_mask, memory_mask) -> mx.array: ...\n\nclass Transformer(Module):\n    \"\"\"\n    Implements a standard Transformer model.\n\n    The implementation is based on `Attention Is All You Need\n    <https://arxiv.org/abs/1706.03762>`_.\n\n    The Transformer model contains an encoder and a decoder. The encoder\n    processes the input sequence and the decoder generates the output sequence.\n    The interaction between encoder and decoder happens through the attention\n    mechanism.\n\n    Args:\n        dims (int, optional): The number of expected features in the\n            encoder/decoder inputs. Default: ``512``.\n        num_heads (int, optional): The number of attention heads. Default:\n            ``8``.\n        num_encoder_layers (int, optional): The number of encoder layers in the\n            Transformer encoder. Default: ``6``.\n        num_decoder_layers (int, optional): The number of decoder layers in the\n            Transformer decoder. Default: ``6``.\n        mlp_dims (int, optional): The hidden dimension of the MLP block in each\n            Transformer layer. Defaults to ``4*dims`` if not provided. Default:\n            ``None``.\n        dropout (float, optional): The dropout value for the Transformer\n            encoder and decoder. Dropout is used after each attention layer and\n            the activation in the MLP layer. Default: ``0.0``.\n        activation (function, optional): the activation function for the MLP\n            hidden layer. Default: :func:`relu`.\n        custom_encoder (nn.Module, optional): A custom encoder to replace the\n            standard Transformer encoder. Default: ``None``.\n        custom_decoder (nn.Module, optional): A custom decoder to replace the\n            standard Transformer decoder. Default: ``None``.\n        norm_first (bool, optional): if ``True``, encoder and decoder layers\n            will perform layer normalization before attention and MLP\n            operations, otherwise after. 
Default: ``True``.\n        checkpoint (bool, optional): if ``True`` perform gradient checkpointing\n            to reduce the memory usage at the expense of more computation.\n            Default: ``False``.\n    \"\"\"\n    def __init__(\n        self,\n        dims: int = ...,\n        num_heads: int = ...,\n        num_encoder_layers: int = ...,\n        num_decoder_layers: int = ...,\n        mlp_dims: Optional[int] = ...,\n        dropout: float = ...,\n        activation: Callable[[Any], Any] = ...,\n        custom_encoder: Optional[Any] = ...,\n        custom_decoder: Optional[Any] = ...,\n        norm_first: bool = ...,\n        checkpoint: bool = ...,\n    ) -> None: ...\n    def __call__(\n        self,\n        src: mx.array,\n        tgt: mx.array,\n        src_mask: mx.array,\n        tgt_mask: mx.array,\n        memory_mask: mx.array,\n    ) -> mx.array: ...\n"
  },
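A minimal usage sketch for the `Transformer` stub above. The causal target mask built via `nn.MultiHeadAttention.create_additive_causal_mask` and the zero (no-op) additive masks are assumptions for illustration, not part of the stub:

```python
import mlx.core as mx
import mlx.nn as nn

# Shapes follow the (batch, sequence, dims) convention described in the docstring.
model = nn.Transformer(dims=64, num_heads=4, num_encoder_layers=2, num_decoder_layers=2)
src = mx.random.normal((1, 10, 64))
tgt = mx.random.normal((1, 8, 64))
src_mask = mx.zeros((10, 10))  # additive mask of zeros: nothing is masked
tgt_mask = nn.MultiHeadAttention.create_additive_causal_mask(8)  # standard MLX helper
memory_mask = mx.zeros((8, 10))
out = model(src, tgt, src_mask, tgt_mask, memory_mask)  # shape (1, 8, 64)
```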
  {
    "path": ".mlx_typings/mlx/nn/layers/upsample.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import Literal, Tuple, Union\n\nimport mlx.core as mx\nfrom base import Module\n\ndef upsample_nearest(x: mx.array, scale_factor: Tuple) -> mx.array: ...\ndef upsample_linear(\n    x: mx.array, scale_factor: Tuple, align_corners: bool = ...\n):  # -> int:\n    ...\ndef upsample_cubic(\n    x: mx.array, scale_factor: Tuple, align_corners: bool = ...\n):  # -> int:\n    ...\n\nclass Upsample(Module):\n    r\"\"\"Upsample the input signal spatially.\n\n    The spatial dimensions are by convention dimensions ``1`` to ``x.ndim -\n    2``. The first is the batch dimension and the last is the feature\n    dimension.\n\n    For example, an audio signal would be 3D with 1 spatial dimension, an image\n    4D with 2 and so on and so forth.\n\n    There are three upsampling algorithms implemented nearest neighbor upsampling,\n    linear interpolation, and cubic interpolation. All can be applied to any number\n    of spatial dimensions. The linear interpolation will be bilinear, trilinear etc\n    when applied to more than one spatial dimension. And cubic interpolation will be\n    bicubic when there are 2 spatial dimensions.\n\n    .. note::\n       When using one of the linear or cubic interpolation modes the ``align_corners``\n       argument changes how the corners are treated in the input image. If\n       ``align_corners=True`` then the top and left edge of the input and\n       output will be matching as will the bottom right edge.\n\n    Parameters:\n        scale_factor (float or tuple): The multiplier for the spatial size.\n            If a ``float`` is provided, it is the multiplier for all spatial dimensions.\n            Otherwise, the number of scale factors provided must match the\n            number of spatial dimensions.\n        mode (str, optional): The upsampling algorithm, either ``\"nearest\"``,\n            ``\"linear\"`` or ``\"cubic\"``. Default: ``\"nearest\"``.\n        align_corners (bool, optional): Changes the way the corners are treated\n            during ``\"linear\"`` and ``\"cubic\"`` upsampling.  See the note above and the\n            examples below for more details.  
Default: ``False``.\n\n    Examples:\n        >>> import mlx.core as mx\n        >>> import mlx.nn as nn\n        >>> x = mx.arange(1, 5).reshape((1, 2, 2, 1))\n        >>> x\n        array([[[[1],\n                 [2]],\n                [[3],\n                 [4]]]], dtype=int32)\n        >>> n = nn.Upsample(scale_factor=2, mode='nearest')\n        >>> n(x).squeeze()\n        array([[1, 1, 2, 2],\n               [1, 1, 2, 2],\n               [3, 3, 4, 4],\n               [3, 3, 4, 4]], dtype=int32)\n        >>> b = nn.Upsample(scale_factor=2, mode='linear')\n        >>> b(x).squeeze()\n        array([[1, 1.25, 1.75, 2],\n               [1.5, 1.75, 2.25, 2.5],\n               [2.5, 2.75, 3.25, 3.5],\n               [3, 3.25, 3.75, 4]], dtype=float32)\n        >>> b = nn.Upsample(scale_factor=2, mode='linear', align_corners=True)\n        >>> b(x).squeeze()\n        array([[1, 1.33333, 1.66667, 2],\n               [1.66667, 2, 2.33333, 2.66667],\n               [2.33333, 2.66667, 3, 3.33333],\n               [3, 3.33333, 3.66667, 4]], dtype=float32)\n    \"\"\"\n    def __init__(\n        self,\n        scale_factor: Union[float, Tuple],\n        mode: Literal[\"nearest\", \"linear\", \"cubic\"] = ...,\n        align_corners: bool = ...,\n    ) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mlx/nn/losses.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import Literal, Optional\n\nimport mlx.core as mx\n\nReduction = Literal[\"none\", \"mean\", \"sum\"]\n\ndef cross_entropy(\n    logits: mx.array,\n    targets: mx.array,\n    weights: Optional[mx.array] = ...,\n    axis: int = ...,\n    label_smoothing: float = ...,\n    reduction: Reduction = ...,\n) -> mx.array:\n    \"\"\"\n    Computes the cross entropy loss.\n\n    Args:\n        logits (array): The unnormalized logits.\n        targets (array): The ground truth values. These can be class indices or\n            probabilities for each class. If the ``targets`` are class indices,\n            then ``targets`` shape should match the ``logits`` shape with\n            the ``axis`` dimension removed. If the ``targets`` are probabilities\n            (or one-hot encoded), then the ``targets`` shape should be the same as\n            the ``logits`` shape.\n        weights (array, optional): Optional weights for each target. Default: ``None``.\n        axis (int, optional): The axis over which to compute softmax. Default: ``-1``.\n        label_smoothing (float, optional): Label smoothing factor. Default: ``0``.\n        reduction (str, optional): Specifies the reduction to apply to the output:\n            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.\n\n    Returns:\n        array: The computed cross entropy loss.\n\n    Examples:\n        >>> import mlx.core as mx\n        >>> import mlx.nn as nn\n        >>>\n        >>> # Class indices as targets\n        >>> logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])\n        >>> targets = mx.array([0, 1])\n        >>> nn.losses.cross_entropy(logits, targets)\n        array([0.0485873, 0.0485873], dtype=float32)\n        >>>\n        >>> # Probabilities (or one-hot vectors) as targets\n        >>> logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])\n        >>> targets = mx.array([[0.9, 0.1], [0.1, 0.9]])\n        >>> nn.losses.cross_entropy(logits, targets)\n        array([0.348587, 0.348587], dtype=float32)\n    \"\"\"\n\ndef binary_cross_entropy(\n    inputs: mx.array,\n    targets: mx.array,\n    weights: Optional[mx.array] = ...,\n    with_logits: bool = ...,\n    reduction: Reduction = ...,\n) -> mx.array:\n    \"\"\"\n    Computes the binary cross entropy loss.\n\n    By default, this function takes the pre-sigmoid logits, which results in a faster\n    and more precise loss. For improved numerical stability when ``with_logits=False``,\n    the loss calculation clips the input probabilities (in log-space) to a minimum value\n    of ``-100``.\n\n    Args:\n        inputs (array): The predicted values. If ``with_logits`` is ``True``, then\n            ``inputs`` are unnormalized logits. Otherwise, ``inputs`` are probabilities.\n        targets (array): The binary target values in {0, 1}.\n        with_logits (bool, optional): Whether ``inputs`` are logits. Default: ``True``.\n        weights (array, optional): Optional weights for each target. Default: ``None``.\n        reduction (str, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'`` | ``'sum'``. 
Default: ``'mean'``.\n\n    Returns:\n        array: The computed binary cross entropy loss.\n    Examples:\n        >>> import mlx.core as mx\n        >>> import mlx.nn as nn\n\n        >>> logits = mx.array([0.105361, 0.223144, 1.20397, 0.916291])\n        >>> targets = mx.array([0, 0, 1, 1])\n        >>> loss = nn.losses.binary_cross_entropy(logits, targets, reduction=\"mean\")\n        >>> loss\n        array(0.539245, dtype=float32)\n\n        >>> probs = mx.array([0.1, 0.1, 0.4, 0.4])\n        >>> targets = mx.array([0, 0, 1, 1])\n        >>> loss = nn.losses.binary_cross_entropy(probs, targets, with_logits=False, reduction=\"mean\")\n        >>> loss\n        array(0.510826, dtype=float32)\n    \"\"\"\n\ndef l1_loss(\n    predictions: mx.array, targets: mx.array, reduction: Reduction = ...\n) -> mx.array:\n    \"\"\"\n    Computes the L1 loss.\n\n    Args:\n        predictions (array): The predicted values.\n        targets (array): The target values.\n        reduction (str, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``.\n\n    Returns:\n        array: The computed L1 loss.\n    \"\"\"\n\ndef mse_loss(\n    predictions: mx.array, targets: mx.array, reduction: Reduction = ...\n) -> mx.array:\n    \"\"\"\n    Computes the mean squared error loss.\n\n    Args:\n        predictions (array): The predicted values.\n        targets (array): The target values.\n        reduction (str, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``.\n\n    Returns:\n        array: The computed mean squared error loss.\n    \"\"\"\n\ndef nll_loss(\n    inputs: mx.array, targets: mx.array, axis: int = ..., reduction: Reduction = ...\n) -> mx.array:\n    \"\"\"\n    Computes the negative log likelihood loss.\n\n    Args:\n        inputs (array): The predicted distribution in log space.\n        targets (array): The target values.\n        axis (int, optional): The distribution axis. Default: ``-1``.\n        reduction (str, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.\n\n    Returns:\n        array: The computed NLL loss.\n    \"\"\"\n\ndef gaussian_nll_loss(\n    inputs: mx.array,\n    targets: mx.array,\n    vars: mx.array,\n    full: bool = ...,\n    eps: float = ...,\n    reduction: Reduction = ...,\n) -> mx.array:\n    r\"\"\"\n    Computes the negative log likelihood loss for a Gaussian distribution.\n\n    The loss is given by:\n\n    .. 
math::\n        \\frac{1}{2}\\left(\\log\\left(\\max\\left(\\text{vars},\n        \\ \\epsilon\\right)\\right) + \\frac{\\left(\\text{inputs} - \\text{targets} \\right)^2}\n        {\\max\\left(\\text{vars}, \\ \\epsilon \\right)}\\right) + \\text{const.}\n\n    where ``inputs`` are the predicted means and ``vars`` are the\n    predicted variances.\n\n    Args:\n        inputs (array): The predicted expectation of the Gaussian distribution.\n        targets (array): The target values (samples from the Gaussian distribution).\n        vars (array): The predicted variance of the Gaussian distribution.\n        full (bool, optional): Whether to include the constant term in the loss calculation.\n            Default: ``False``.\n        eps (float, optional): Small positive constant for numerical stability.\n            Default: ``1e-6``.\n        reduction (str, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.\n\n    Returns:\n        array: The Gaussian NLL loss.\n    \"\"\"\n\ndef kl_div_loss(\n    inputs: mx.array, targets: mx.array, axis: int = ..., reduction: Reduction = ...\n) -> mx.array:\n    \"\"\"\n    Computes the Kullback-Leibler divergence loss.\n\n    Computes the following when ``reduction == 'none'``:\n\n    .. code-block:: python\n\n        (mx.exp(targets) * (targets - inputs)).sum(axis)\n\n    Args:\n        inputs (array): Log probabilities for the predicted distribution.\n        targets (array): Log probabilities for the target distribution.\n        axis (int, optional): The distribution axis. Default: ``-1``.\n        reduction (str, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.\n\n    Returns:\n        array: The computed Kullback-Leibler divergence loss.\n    \"\"\"\n\ndef smooth_l1_loss(\n    predictions: mx.array,\n    targets: mx.array,\n    beta: float = ...,\n    reduction: Reduction = ...,\n) -> mx.array:\n    r\"\"\"\n    Computes the smooth L1 loss.\n\n    The smooth L1 loss is a variant of the L1 loss which replaces the absolute\n    difference with a squared difference when the absolute difference is less\n    than ``beta``.\n\n    The formula for the smooth L1 loss is:\n\n    .. math::\n\n      l = \\begin{cases}\n            0.5 (x - y)^2 / \\beta, & \\text{if } |x - y| < \\beta \\\\\n            |x - y| - 0.5 \\beta, & \\text{otherwise}\n          \\end{cases}\n\n    Args:\n        predictions (array): Predicted values.\n        targets (array): Ground truth values.\n        beta (float, optional): The threshold after which the loss changes\n          from the squared to the absolute difference. Default: ``1.0``.\n        reduction (str, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``.\n\n    Returns:\n        array: The computed smooth L1 loss.\n    \"\"\"\n\ndef triplet_loss(\n    anchors: mx.array,\n    positives: mx.array,\n    negatives: mx.array,\n    axis: int = ...,\n    p: int = ...,\n    margin: float = ...,\n    eps: float = ...,\n    reduction: Reduction = ...,\n) -> mx.array:\n    r\"\"\"\n    Computes the triplet loss for a set of anchor, positive, and negative samples.\n    Margin is represented with alpha in the math section.\n\n    .. 
math::\n\n       \\max\\left(\\|A - P\\|_p - \\|A - N\\|_p + \\alpha, 0\\right)\n\n    Args:\n        anchors (array): The anchor samples.\n        positives (array): The positive samples.\n        negatives (array): The negative samples.\n        axis (int, optional): The distribution axis. Default: ``-1``.\n        p (int, optional): The norm degree for pairwise distance. Default: ``2``.\n        margin (float, optional): Margin for the triplet loss. Defaults to ``1.0``.\n        eps (float, optional): Small positive constant to prevent numerical instability. Defaults to ``1e-6``.\n        reduction (str, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.\n\n    Returns:\n        array: Computed triplet loss. If reduction is \"none\", returns a tensor of the same shape as input;\n                  if reduction is \"mean\" or \"sum\", returns a scalar tensor.\n    \"\"\"\n\ndef hinge_loss(\n    inputs: mx.array, targets: mx.array, reduction: Reduction = ...\n) -> mx.array:\n    r\"\"\"\n    Computes the hinge loss between inputs and targets.\n\n    .. math::\n\n       \\text{hinge}(y, y_{\\text{pred}}) = \\max(0, 1 - y \\cdot y_{\\text{pred}})\n\n\n    Args:\n        inputs (array): The predicted values.\n        targets (array): The target values. They should be -1 or 1.\n        reduction (str, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.\n\n    Returns:\n        array: The computed hinge loss.\n    \"\"\"\n\ndef huber_loss(\n    inputs: mx.array, targets: mx.array, delta: float = ..., reduction: Reduction = ...\n) -> mx.array:\n    r\"\"\"\n    Computes the Huber loss between inputs and targets.\n\n    .. math::\n\n        l_{\\delta}(a) =\n        \\left\\{ \\begin{array}{ll}\n            \\frac{1}{2} a^2 & \\text{for } |a| \\leq \\delta, \\\\\n            \\delta \\left( |a| - \\frac{1}{2} \\delta \\right) & \\text{otherwise.}\n        \\end{array} \\right.\n\n    Args:\n        inputs (array): The predicted values.\n        targets (array): The target values.\n        delta (float, optional): The threshold at which to change between L1 and L2 loss.\n          Default: ``1.0``.\n        reduction (str, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.\n\n    Returns:\n        array: The computed Huber loss.\n    \"\"\"\n\ndef log_cosh_loss(\n    inputs: mx.array, targets: mx.array, reduction: Reduction = ...\n) -> mx.array:\n    r\"\"\"\n    Computes the log cosh loss between inputs and targets.\n\n    Logcosh acts like L2 loss for small errors, ensuring stable gradients,\n    and like the L1 loss for large errors, reducing sensitivity to outliers. This\n    dual behavior offers a balanced, robust approach for regression tasks.\n\n    .. math::\n\n       \\text{logcosh}(y_{\\text{true}}, y_{\\text{pred}}) =\n            \\frac{1}{n} \\sum_{i=1}^{n}\n            \\log(\\cosh(y_{\\text{pred}}^{(i)} - y_{\\text{true}}^{(i)}))\n\n\n    Args:\n        inputs (array): The predicted values.\n        targets (array): The target values.\n        reduction (str, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'`` | ``'sum'``. 
Default: ``'none'``.\n\n    Returns:\n        array: The computed log cosh loss.\n    \"\"\"\n\ndef cosine_similarity_loss(\n    x1: mx.array,\n    x2: mx.array,\n    axis: int = ...,\n    eps: float = ...,\n    reduction: Reduction = ...,\n) -> mx.array:\n    r\"\"\"\n    Computes the cosine similarity between the two inputs.\n\n    The cosine similarity loss is given by\n\n    .. math::\n\n        \\frac{x_1 \\cdot x_2}{\\max(\\|x_1\\|  \\cdot \\|x_2\\|, \\epsilon)}\n\n    Args:\n        x1 (mx.array): The first set of inputs.\n        x2 (mx.array): The second set of inputs.\n        axis (int, optional): The embedding axis. Default: ``1``.\n        eps (float, optional): The minimum value of the denominator used for\n          numerical stability. Default: ``1e-8``.\n        reduction (str, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.\n\n    Returns:\n        mx.array: The computed cosine similarity loss.\n    \"\"\"\n\ndef margin_ranking_loss(\n    inputs1: mx.array,\n    inputs2: mx.array,\n    targets: mx.array,\n    margin: float = ...,\n    reduction: Reduction = ...,\n) -> mx.array:\n    r\"\"\"\n    Calculate the margin ranking loss given inputs :math:`x_1`, :math:`x_2` and a label\n    :math:`y` (containing 1 or -1).\n\n    The loss is given by:\n\n    .. math::\n        \\text{loss} = \\max (0, -y * (x_1 - x_2) + \\text{margin})\n\n    Where :math:`y` represents ``targets``, :math:`x_1` represents ``inputs1`` and :math:`x_2`\n    represents ``inputs2``.\n\n    Args:\n        inputs1 (array): Scores for the first input.\n        inputs2 (array): Scores for the second input.\n        targets (array): Labels indicating whether samples in ``inputs1`` should be ranked higher\n            than samples in ``inputs2``. Values should be 1 or -1.\n        margin (float, optional): The margin by which the scores should be separated.\n            Default: ``0.0``.\n        reduction (str, optional): Specifies the reduction to apply to the output:\n            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.\n\n    Returns:\n        array: The computed margin ranking loss.\n\n    Examples:\n        >>> import mlx.core as mx\n        >>> import mlx.nn as nn\n        >>> targets = mx.array([1, 1, -1])\n        >>> inputs1 = mx.array([-0.573409, -0.765166, -0.0638])\n        >>> inputs2 = mx.array([0.75596, 0.225763, 0.256995])\n        >>> loss = nn.losses.margin_ranking_loss(inputs1, inputs2, targets)\n        >>> loss\n        array(0.773433, dtype=float32)\n    \"\"\"\n"
  },
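Every loss in this stub shares the same ``Reduction`` literal. A small sketch of the three modes using `mse_loss`; the input values are illustrative only:

```python
import mlx.core as mx
import mlx.nn as nn

predictions = mx.array([1.0, 2.0, 3.0])
targets = mx.array([1.5, 2.0, 2.5])

per_element = nn.losses.mse_loss(predictions, targets, reduction="none")  # shape (3,)
mean_loss = nn.losses.mse_loss(predictions, targets, reduction="mean")    # scalar array
sum_loss = nn.losses.mse_loss(predictions, targets, reduction="sum")      # scalar array
```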
  {
    "path": ".mlx_typings/mlx/nn/utils.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import Any, Callable, Optional\n\nimport mlx.core as mx\n\nfrom .layers.base import Module\n\ndef value_and_grad(\n    model: Module, fn: Callable\n):  # -> _Wrapped[..., Any, ..., tuple[Any, Any]]:\n    \"\"\"Transform the passed function ``fn`` to a function that computes the\n    gradients of ``fn`` wrt the model's trainable parameters and also its\n    value.\n\n    Args:\n        model (Module): The model whose trainable parameters to compute\n                               gradients for\n        fn (Callable): The scalar function to compute gradients for\n\n    Returns:\n        A callable that returns the value of ``fn`` and the gradients wrt the\n        trainable parameters of ``model``\n    \"\"\"\n\ndef checkpoint(\n    module: Module, fn: Optional[Callable] = ...\n):  # -> _Wrapped[..., Any, ..., Any]:\n    \"\"\"Transform the passed callable to one that performs gradient\n    checkpointing with respect to the trainable parameters of the module (and\n    the callable's inputs).\n\n    Args:\n        module (Module): The module for whose parameters we will be\n            performing gradient checkpointing.\n        fn (Callable, optional): The function to checkpoint. If not provided it\n            defaults to the provided module.\n\n    Returns:\n        A callable that saves the inputs and outputs during the forward pass\n        and recomputes all intermediate states during the backward pass.\n    \"\"\"\n\ndef average_gradients(\n    gradients: Any,\n    group: Optional[mx.distributed.Group] = ...,\n    all_reduce_size: int = ...,\n    communication_type: Optional[mx.Dtype] = ...,\n    communication_stream: Optional[mx.Stream] = ...,\n):  # -> Any:\n    \"\"\"Average the gradients across the distributed processes in the passed group.\n\n    This helper enables concatenating several gradients of small arrays to one\n    big all reduce call for better networking performance.\n\n    Args:\n        gradients (Any): The Python tree containing the gradients (it should\n            have the same structure across processes)\n        group (Optional[mlx.core.distributed.Group]): The group of processes to\n            average the gradients. If set to ``None`` the global group is used.\n            Default: ``None``.\n        all_reduce_size (int): Group arrays until their size in bytes exceeds\n            this number. Perform one communication step per group of arrays. If\n            less or equal to 0 array grouping is disabled. Default: ``32MiB``.\n        communication_type (Optional[mlx.core.Dtype]): If provided cast to this\n            type before performing the communication. Typically cast to a\n            smaller float to reduce the communication size. Default: ``None``.\n        communication_stream (Optional[mlx.core.Stream]): The stream to usse\n            for the communication. If unspecified the default communication\n            stream is used which can vary by back-end. Default: ``None``.\n    \"\"\"\n"
  },
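A short sketch of `value_and_grad` in use, following its docstring; the toy model and loss function are assumptions for illustration:

```python
import mlx.core as mx
import mlx.nn as nn

model = nn.Linear(4, 2)

def loss_function(model: nn.Module, x: mx.array, y: mx.array) -> mx.array:
    return nn.losses.mse_loss(model(x), y, reduction="mean")

# The wrapped function takes the same arguments as loss_function and returns
# (value, gradients with respect to the model's trainable parameters).
loss_and_grad_fn = nn.value_and_grad(model, loss_function)
x, y = mx.random.normal((8, 4)), mx.random.normal((8, 2))
loss, grads = loss_and_grad_fn(model, x, y)
```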
  {
    "path": ".mlx_typings/mlx/utils.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nfrom mlx.core import MX_ARRAY_TREE\n\ndef tree_map(\n    fn: Callable[..., Any],\n    tree: Any,\n    *rest: Any,\n    is_leaf: Callable[..., bool] | None = ...,\n) -> Any:\n    \"\"\"Applies ``fn`` to the leaves of the Python tree ``tree`` and\n    returns a new collection with the results.\n\n    If ``rest`` is provided, every item is assumed to be a superset of ``tree``\n    and the corresponding leaves are provided as extra positional arguments to\n    ``fn``. In that respect, :meth:`tree_map` is closer to :func:`itertools.starmap`\n    than to :func:`map`.\n\n    The keyword argument ``is_leaf`` decides what constitutes a leaf from\n    ``tree`` similar to :func:`tree_flatten`.\n\n    .. code-block:: python\n\n        import mlx.nn as nn\n        from mlx.utils import tree_map\n\n        model = nn.Linear(10, 10)\n        print(model.parameters().keys())\n        # dict_keys(['weight', 'bias'])\n\n        # square the parameters\n        model.update(tree_map(lambda x: x*x, model.parameters()))\n\n    Args:\n        fn (callable): The function that processes the leaves of the tree.\n        tree (Any): The main Python tree that will be iterated upon.\n        rest (tuple[Any]): Extra trees to be iterated together with ``tree``.\n        is_leaf (callable, optional): An optional callable that returns ``True``\n           if the passed object is considered a leaf or ``False`` otherwise.\n\n    Returns:\n        A Python tree with the new values returned by ``fn``.\n    \"\"\"\n\ndef tree_map_with_path(\n    fn: Callable[..., Any],\n    tree: Any,\n    *rest: Any,\n    is_leaf: Callable[..., bool] | None = ...,\n    path: str | None = ...,\n) -> Any:\n    \"\"\"Applies ``fn`` to the path and leaves of the Python tree ``tree`` and\n    returns a new collection with the results.\n\n    This function is the same :func:`tree_map` but the ``fn`` takes the path as\n    the first argument followed by the remaining tree nodes.\n\n    Args:\n        fn (callable): The function that processes the leaves of the tree.\n        tree (Any): The main Python tree that will be iterated upon.\n        rest (tuple[Any]): Extra trees to be iterated together with ``tree``.\n        is_leaf (Optional[Callable]): An optional callable that returns ``True``\n           if the passed object is considered a leaf or ``False`` otherwise.\n        path (Optional[Any]): Prefix will be added to the result.\n\n    Returns:\n        A Python tree with the new values returned by ``fn``.\n\n    Example:\n        >>> from mlx.utils import tree_map_with_path\n        >>> tree = {\"model\": [{\"w\": 0, \"b\": 1}, {\"w\": 0, \"b\": 1}]}\n        >>> new_tree = tree_map_with_path(lambda path, _: print(path), tree)\n        model.0.w\n        model.0.b\n        model.1.w\n        model.1.b\n    \"\"\"\n\ndef tree_flatten(\n    tree: Any,\n    prefix: str = ...,\n    is_leaf: Callable[..., bool] | None = ...,\n    destination: list[tuple[str, Any]] | dict[str, Any] | None = ...,\n) -> list[tuple[str, Any]] | dict[str, Any]:\n    \"\"\"Flattens a Python tree to a list of key, value tuples.\n\n    The keys are using the dot notation to define trees of arbitrary depth and\n    complexity.\n\n    .. 
code-block:: python\n\n        from mlx.utils import tree_flatten\n\n        print(tree_flatten([[[0]]]))\n        # [(\"0.0.0\", 0)]\n\n        print(tree_flatten([[[0]]], prefix=\".hello\"))\n        # [(\"hello.0.0.0\", 0)]\n\n        tree_flatten({\"a\": {\"b\": 1}}, destination={})\n        # {\"a.b\": 1}\n\n    .. note::\n       Dictionaries should have keys that are valid Python identifiers.\n\n    Args:\n        tree (Any): The Python tree to be flattened.\n        prefix (str): A prefix to use for the keys. The first character is\n            always discarded.\n        is_leaf (callable): An optional callable that returns True if the\n            passed object is considered a leaf or False otherwise.\n        destination (list or dict, optional): A list or dictionary to store the\n            flattened tree. If None an empty list will be used. Default: ``None``.\n\n    Returns:\n        Union[List[Tuple[str, Any]], Dict[str, Any]]: The flat representation of\n            the Python tree.\n    \"\"\"\n\ndef tree_unflatten(tree: list[tuple[str, Any]] | dict[str, Any]) -> Any:\n    \"\"\"Recreate a Python tree from its flat representation.\n\n    .. code-block:: python\n\n        from mlx.utils import tree_unflatten\n\n        d = tree_unflatten([(\"hello.world\", 42)])\n        print(d)\n        # {\"hello\": {\"world\": 42}}\n\n        d = tree_unflatten({\"hello.world\": 42})\n        print(d)\n        # {\"hello\": {\"world\": 42}}\n\n    Args:\n        tree (list[tuple[str, Any]] or dict[str, Any]): The flat representation of a Python tree.\n           For instance as returned by :meth:`tree_flatten`.\n\n    Returns:\n        A Python tree.\n    \"\"\"\n\ndef tree_reduce(\n    fn: Callable[[Any, Any], Any],\n    tree: list[MX_ARRAY_TREE] | tuple[MX_ARRAY_TREE, ...] | dict[str, MX_ARRAY_TREE],\n    initializer=...,\n    is_leaf=...,\n) -> Any:\n    \"\"\"Applies a reduction to the leaves of a Python tree.\n\n    This function reduces Python trees into an accumulated result by applying\n    the provided function ``fn`` to the leaves of the tree.\n\n    Example:\n        >>> from mlx.utils import tree_reduce\n        >>> tree = {\"a\": [1, 2, 3], \"b\": [4, 5]}\n        >>> tree_reduce(lambda acc, x: acc + x, tree, 0)\n        15\n\n    Args:\n        fn (callable): The reducer function that takes two arguments (accumulator,\n            current value) and returns the updated accumulator.\n        tree (Any): The Python tree to reduce. It can be any nested combination of\n            lists, tuples, or dictionaries.\n        initializer (Any, optional): The initial value to start the reduction. If\n            not provided, the first leaf value is used.\n        is_leaf (callable, optional): A function to determine if an object is a\n            leaf, returning ``True`` for leaf nodes and ``False`` otherwise.\n\n    Returns:\n        Any: The accumulated value.\n    \"\"\"\n\ndef tree_merge(\n    tree_a, tree_b, merge_fn=...\n):  # -> dict[Any, Any] | list[Any] | tuple[Any, *tuple[Any, ...]] | tuple[Any, ...]:\n    \"\"\"Merge two Python trees into one containing the values of both. It can be\n    thought of as a deep dict.update method.\n\n    Args:\n        tree_a (Any): The first Python tree.\n        tree_b (Any): The second Python tree.\n        merge_fn (callable, optional): A function to merge leaves.\n\n    Returns:\n        The Python tree containing the values of both ``tree_a`` and\n        ``tree_b``.\n    \"\"\"\n"
  },
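The tree utilities compose: `tree_flatten` and `tree_unflatten` are inverses, and `tree_merge` acts like a deep `dict.update`. A sketch; the disjoint-leaf merge result is an assumption based on the docstring, since leaves defined in both trees would need a `merge_fn`:

```python
from mlx.utils import tree_flatten, tree_merge, tree_unflatten

params = {"layers": [{"w": 0, "b": 1}, {"w": 2, "b": 3}]}
flat = tree_flatten(params)
# [("layers.0.w", 0), ("layers.0.b", 1), ("layers.1.w", 2), ("layers.1.b", 3)]
assert tree_unflatten(flat) == params  # round trip

merged = tree_merge({"a": {"x": 1}}, {"a": {"y": 2}})
# assumed result: {"a": {"x": 1, "y": 2}}
```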
  {
    "path": ".mlx_typings/mlx_lm/__init__.pyi",
    "content": "import models as models\nimport tokenizer_utils as tokenizer_utils\nfrom generate import *\n"
  },
  {
    "path": ".mlx_typings/mlx_lm/convert.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport argparse\nfrom typing import Callable, Optional, Union\n\nimport mlx.nn as nn\n\ndef mixed_quant_predicate_builder(\n    recipe: str, model: nn.Module, group_size: int = ...\n) -> Callable[[str, nn.Module, dict], Union[bool, dict]]: ...\n\nQUANT_RECIPES = ...\nMODEL_CONVERSION_DTYPES = ...\n\ndef convert(\n    hf_path: str,\n    mlx_path: str = ...,\n    quantize: bool = ...,\n    q_group_size: int = ...,\n    q_bits: int = ...,\n    q_mode: str = ...,\n    dtype: Optional[str] = ...,\n    upload_repo: str = ...,\n    revision: Optional[str] = ...,\n    dequantize: bool = ...,\n    quant_predicate: Optional[\n        Union[Callable[[str, nn.Module, dict], Union[bool, dict]], str]\n    ] = ...,\n    trust_remote_code: bool = ...,\n):  # -> None:\n    ...\ndef configure_parser() -> argparse.ArgumentParser:\n    \"\"\"\n    Configures and returns the argument parser for the script.\n\n    Returns:\n        argparse.ArgumentParser: Configured argument parser.\n    \"\"\"\n\ndef main():  # -> None:\n    ...\n\nif __name__ == \"__main__\": ...\n"
  },
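A sketch of the `convert` entry point above, quantizing a Hugging Face checkpoint into an MLX model directory; the paths are illustrative:

```python
from mlx_lm.convert import convert

convert(
    hf_path="mistralai/Mistral-7B-Instruct-v0.3",  # illustrative source checkpoint
    mlx_path="./mistral-7b-4bit",                  # illustrative output directory
    quantize=True,
    q_bits=4,
    q_group_size=64,
)
```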
  {
    "path": ".mlx_typings/mlx_lm/generate.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport contextlib\nfrom dataclasses import dataclass\nfrom typing import Any, Callable, Generator, List, Optional, Tuple, Union\n\nimport mlx.core as mx\nimport mlx.nn as nn\nfrom transformers import PreTrainedTokenizer\n\nfrom .tokenizer_utils import TokenizerWrapper\n\nDEFAULT_PROMPT = ...\nDEFAULT_MAX_TOKENS = ...\nDEFAULT_TEMP = ...\nDEFAULT_TOP_P = ...\nDEFAULT_MIN_P = ...\nDEFAULT_TOP_K = ...\nDEFAULT_XTC_PROBABILITY = ...\nDEFAULT_XTC_THRESHOLD = ...\nDEFAULT_MIN_TOKENS_TO_KEEP = ...\nDEFAULT_SEED = ...\nDEFAULT_MODEL = ...\nDEFAULT_QUANTIZED_KV_START = ...\n\ndef str2bool(string):  # -> bool:\n    ...\ndef setup_arg_parser():  # -> ArgumentParser:\n    \"\"\"Set up and return the argument parser.\"\"\"\n\ngeneration_stream = ...\n\n@contextlib.contextmanager\ndef wired_limit(\n    model: nn.Module, streams: Optional[List[mx.Stream]] = ...\n):  # -> Generator[None, Any, None]:\n    \"\"\"\n    A context manager to temporarily change the wired limit.\n\n    Note, the wired limit should not be changed during an async eval.  If an\n    async eval could be running pass in the streams to synchronize with prior\n    to exiting the context manager.\n    \"\"\"\n@dataclass\nclass GenerationResponse:\n    \"\"\"\n    The output of :func:`stream_generate`.\n\n    Args:\n        text (str): The next segment of decoded text. This can be an empty string.\n        token (int): The next token.\n        from_draft (bool): Whether the token was generated by the draft model.\n        logprobs (mx.array): A vector of log probabilities.\n        prompt_tokens (int): The number of tokens in the prompt.\n        prompt_tps (float): The prompt processing tokens-per-second.\n        generation_tokens (int): The number of generated tokens.\n        generation_tps (float): The tokens-per-second for generation.\n        peak_memory (float): The peak memory used so far in GB.\n        finish_reason (str): The reason the response is being sent: \"length\", \"stop\" or `None`\n    \"\"\"\n\n    text: str\n    token: int\n    logprobs: mx.array\n    from_draft: bool\n    prompt_tokens: int\n    prompt_tps: float\n    generation_tokens: int\n    generation_tps: float\n    peak_memory: float\n    finish_reason: Optional[str] = ...\n\ndef maybe_quantize_kv_cache(\n    prompt_cache: Any,\n    quantized_kv_start: int | None,\n    kv_group_size: int | None,\n    kv_bits: int | None,\n) -> None: ...\ndef generate_step(\n    prompt: mx.array,\n    model: nn.Module,\n    *,\n    max_tokens: int = ...,\n    sampler: Optional[Callable[[mx.array], mx.array]] = ...,\n    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = ...,\n    max_kv_size: Optional[int] = ...,\n    prompt_cache: Optional[Any] = ...,\n    prefill_step_size: int = ...,\n    kv_bits: Optional[int] = ...,\n    kv_group_size: int = ...,\n    quantized_kv_start: int = ...,\n    prompt_progress_callback: Optional[Callable[[int], int]] = ...,\n    input_embeddings: Optional[mx.array] = ...,\n) -> Generator[Tuple[mx.array, mx.array], None, None]:\n    \"\"\"\n    A generator producing token ids based on the given prompt from the model.\n\n    Args:\n        prompt (mx.array): The input prompt.\n        model (nn.Module): The model to use for generation.\n        max_tokens (int): The maximum number of tokens. Use``-1`` for an infinite\n          generator. 
Default: ``256``.\n        sampler (Callable[[mx.array], mx.array], optional): A sampler for sampling a\n          token from a vector of log probabilities. Default: ``None``.\n        logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):\n          A list of functions that take tokens and logits and return the processed\n          logits. Default: ``None``.\n        max_kv_size (int, optional): Maximum size of the key-value cache. Old\n          entries (except the first 4 tokens) will be overwritten.\n        prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if\n          provided, the cache will be updated in place.\n        prefill_step_size (int): Step size for processing the prompt.\n        kv_bits (int, optional): Number of bits to use for KV cache quantization.\n          None implies no cache quantization. Default: ``None``.\n        kv_group_size (int): Group size for KV cache quantization. Default: ``64``.\n        quantized_kv_start (int): Step to begin using a quantized KV cache\n           when ``kv_bits`` is non-None. Default: ``0``.\n        prompt_progress_callback (Callable[[int, int], None]): A callback which takes the\n           prompt tokens processed so far and the total number of prompt tokens.\n        input_embeddings (mx.array, optional): Input embeddings to use instead of or in\n          conjunction with prompt tokens. Default: ``None``.\n\n    Yields:\n        Tuple[mx.array, mx.array]: One token and a vector of log probabilities.\n    \"\"\"\n\ndef speculative_generate_step(\n    prompt: mx.array,\n    model: nn.Module,\n    draft_model: nn.Module,\n    *,\n    num_draft_tokens: int = ...,\n    max_tokens: int = ...,\n    sampler: Optional[Callable[[mx.array], mx.array]] = ...,\n    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = ...,\n    prompt_cache: Optional[Any] = ...,\n    prefill_step_size: int = ...,\n    kv_bits: Optional[int] = ...,\n    kv_group_size: int = ...,\n    quantized_kv_start: int = ...,\n) -> Generator[Tuple[mx.array, mx.array, bool], None, None]:\n    \"\"\"\n    A generator producing token ids based on the given prompt from the model.\n\n    Args:\n        prompt (mx.array): The input prompt.\n        model (nn.Module): The model to use for generation.\n        draft_model (nn.Module): The draft model for speculative decoding.\n        num_draft_tokens (int, optional): The number of draft tokens for\n          speculative decoding. Default: ``2``.\n        max_tokens (int): The maximum number of tokens. Use ``-1`` for an infinite\n          generator. Default: ``256``.\n        sampler (Callable[[mx.array], mx.array], optional): A sampler for sampling a\n          token from a vector of log probabilities. Default: ``None``.\n        logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):\n          A list of functions that take tokens and logits and return the processed\n          logits. Default: ``None``.\n        prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if\n          provided, the cache will be updated in place. The cache must be trimmable.\n        prefill_step_size (int): Step size for processing the prompt.\n        kv_bits (int, optional): Number of bits to use for KV cache quantization.\n          None implies no cache quantization. Default: ``None``.\n        kv_group_size (int): Group size for KV cache quantization. 
Default: ``64``.\n        quantized_kv_start (int): Step to begin using a quantized KV cache\n           when ``kv_bits`` is non-None. Default: ``0``.\n\n    Yields:\n        Tuple[mx.array, mx.array, bool]: One token, a vector of log probabilities,\n          and a bool indicating if the token was generated by the draft model.\n    \"\"\"\n\ndef stream_generate(\n    model: nn.Module,\n    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],\n    prompt: Union[str, mx.array, List[int]],\n    max_tokens: int = ...,\n    draft_model: Optional[nn.Module] = ...,\n    **kwargs: object,\n) -> Generator[GenerationResponse, None, None]:\n    \"\"\"\n    A generator producing text based on the given prompt from the model.\n\n    Args:\n        model (nn.Module): The model to use for generation.\n        tokenizer (PreTrainedTokenizer): The tokenizer.\n        prompt (Union[str, mx.array, List[int]]): The input prompt string or\n          integer tokens.\n        max_tokens (int): The maximum number of tokens to generate.\n          Default: ``256``.\n        draft_model (Optional[nn.Module]): An optional draft model. If provided\n          then speculative decoding is used. The draft model must use the same\n          tokenizer as the main model. Default: ``None``.\n        kwargs: The remaining options get passed to :func:`generate_step`.\n          See :func:`generate_step` for more details.\n\n    Yields:\n        GenerationResponse: An instance containing the generated text segment and\n            associated metadata. See :class:`GenerationResponse` for details.\n    \"\"\"\n\ndef generate(\n    model: nn.Module,\n    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],\n    prompt: Union[str, List[int]],\n    verbose: bool = ...,\n    **kwargs,\n) -> str:\n    \"\"\"\n    Generate a complete response from the model.\n\n    Args:\n       model (nn.Module): The language model.\n       tokenizer (PreTrainedTokenizer): The tokenizer.\n       prompt (Union[str, List[int]]): The input prompt string or integer tokens.\n       verbose (bool): If ``True``, print tokens and timing information.\n           Default: ``False``.\n       kwargs: The remaining options get passed to :func:`stream_generate`.\n          See :func:`stream_generate` for more details.\n    \"\"\"\n@dataclass\nclass BatchStats:\n    \"\"\"\n    A data object to hold generation stats.\n\n    Args:\n        prompt_tokens (int): The number of prompt tokens processed.\n        prompt_tps (float): The prompt processing tokens-per-second.\n        prompt_time (float): The time in seconds spent in prompt processing.\n        generation_tokens (int): The number of generated tokens.\n        generation_tps (float): The tokens-per-second for generation.\n        generation_time (float): The time in seconds spent in generation.\n        peak_memory (float): The peak memory used so far in GB.\n    \"\"\"\n\n    prompt_tokens: int = ...\n    prompt_tps: float = ...\n    prompt_time: float = ...\n    generation_tokens: int = ...\n    generation_tps: float = ...\n    generation_time: float = ...\n    peak_memory: float = ...\n\n@dataclass\nclass BatchResponse:\n    \"\"\"\n    A data object to hold a batch generation response.\n\n    Args:\n        texts (List[str]): The generated text for each prompt.\n        stats (BatchStats): Statistics about the generation.\n    \"\"\"\n\n    texts: List[str]\n    stats: BatchStats\n    caches: Optional[List[List[Any]]]\n\ndef _left_pad_prompts(prompts: Any, max_length: Optional[int] = ...) 
-> mx.array: ...\ndef _right_pad_prompts(prompts: Any, max_length: Optional[int] = ...) -> mx.array: ...\ndef _make_cache(\n    model: Any, left_padding: Any, max_kv_size: Optional[int]\n) -> List[Any]: ...\ndef _merge_caches(caches: Any) -> List[Any]: ...\n@dataclass\nclass Batch:\n    uids: List[int]\n    y: mx.array\n    logprobs: mx.array\n    max_tokens: List[int]\n    num_tokens: List[int]\n    cache: List[Any]\n    samplers: List[Any]\n    logits_processors: List[Any]\n    tokens: List[mx.array]\n    def __len__(self) -> int: ...\n    def filter(self, keep_idx: List[int]) -> None: ...\n    def extend(self, other: \"Batch\") -> None: ...\n    def extract_cache(self, idx: int) -> List[Any]: ...\n\nclass BatchGenerator:\n    model: Any\n    max_kv_size: Optional[int]\n    prefill_step_size: int\n    unprocessed_prompts: List[Any]\n    active_batch: Optional[Batch]\n    prompt_progress_callback: Callable[[List[Tuple[int, int, int]]], None]\n    _stats: BatchStats\n\n    @dataclass\n    class Response:\n        uid: int\n        token: int\n        logprobs: mx.array\n        finish_reason: Optional[str]\n        prompt_cache: Any\n\n    def __init__(\n        self,\n        model: Any,\n        max_tokens: int = ...,\n        stop_tokens: Optional[set[int]] = ...,\n        sampler: Optional[Callable[[mx.array], mx.array]] = ...,\n        logits_processors: Optional[\n            List[Callable[[mx.array, mx.array], mx.array]]\n        ] = ...,\n        completion_batch_size: int = ...,\n        prefill_batch_size: int = ...,\n        prefill_step_size: int = ...,\n        prompt_progress_callback: Optional[\n            Callable[[List[Tuple[int, int, int]]], None]\n        ] = ...,\n        max_kv_size: Optional[int] = ...,\n    ) -> None: ...\n    def close(self) -> None: ...\n    def insert(\n        self,\n        prompts: Any,\n        max_tokens: Union[List[int], int, None] = ...,\n        caches: Any = ...,\n        samplers: Optional[List[Any]] = ...,\n        logits_processors: Optional[List[Any]] = ...,\n    ) -> List[int]: ...\n    def remove(\n        self, uids: List[int], return_prompt_caches: bool = ...\n    ) -> Optional[dict[int, List[Any]]]: ...\n    def stats(self) -> BatchStats: ...\n    def next(self) -> List[Response]: ...\n    def _process_prompts(self, prompts: List[Any]) -> Batch: ...\n    def _step(\n        self,\n        input_tokens: mx.array,\n        prompt_cache: List[Any],\n        samplers: Optional[List[Any]],\n        logits_processors: Optional[List[Any]],\n        tokens: List[mx.array],\n    ) -> Tuple[mx.array, List[mx.array]]: ...\n\ndef batch_generate(\n    model,\n    tokenizer,\n    prompts: List[List[int]],\n    max_tokens: Union[int, List[int]] = ...,\n    verbose: bool = ...,\n    **kwargs,\n) -> BatchResponse:\n    \"\"\"\n    Generate responses for the given batch of prompts.\n\n    Args:\n       model (nn.Module): The language model.\n       tokenizer (PreTrainedTokenizer): The tokenizer.\n       prompts (List[List[int]]): The input prompts.\n       verbose (bool): If ``True``, print tokens and timing information.\n          Default: ``False``.\n       max_tokens (Union[int, List[int]]): Maximum number of output tokens. This\n          can be per prompt if a list is provided.\n       kwargs: The remaining options get passed to :obj:`BatchGenerator`.\n          See :obj:`BatchGenerator` for more details.\n    \"\"\"\n\ndef main():  # -> None:\n    ...\n\nif __name__ == \"__main__\": ...\n"
  },
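A sketch of the streaming API above. `load` lives in `mlx_lm.utils` and is not among these stubs, and the model name is illustrative:

```python
from mlx_lm import load, stream_generate

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")  # illustrative model

for response in stream_generate(model, tokenizer, "Hello", max_tokens=32):
    # Each GenerationResponse carries one decoded segment plus timing metadata.
    print(response.text, end="", flush=True)
print(f"\n{response.generation_tps:.1f} tokens/sec")
```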
  {
    "path": ".mlx_typings/mlx_lm/models/__init__.pyi",
    "content": "import cache as cache\n"
  },
  {
    "path": ".mlx_typings/mlx_lm/models/base.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom dataclasses import dataclass\nfrom typing import Optional\n\nimport mlx.core as mx\n\n@dataclass\nclass BaseModelArgs:\n    @classmethod\n    def from_dict(cls, params):  # -> Self:\n        ...\n\ndef create_causal_mask(\n    N: int,\n    offset: int = ...,\n    window_size: Optional[int] = ...,\n    right_padding: Optional[mx.array] = ...,\n    left_padding: Optional[mx.array] = ...,\n):  # -> array:\n    ...\ndef create_attention_mask(\n    h, cache=..., window_size: Optional[int] = ..., return_array: bool = ...\n):  # -> array | Literal['causal'] | None:\n    ...\ndef create_ssm_mask(h, cache=...):  # -> None:\n    ...\ndef quantized_scaled_dot_product_attention(\n    queries: mx.array,\n    q_keys: tuple[mx.array, mx.array, mx.array],\n    q_values: tuple[mx.array, mx.array, mx.array],\n    scale: float,\n    mask: Optional[mx.array],\n    group_size: int = ...,\n    bits: int = ...,\n) -> mx.array: ...\ndef scaled_dot_product_attention(\n    queries,\n    keys,\n    values,\n    cache,\n    scale: float,\n    mask: Optional[mx.array],\n    sinks: Optional[mx.array] = ...,\n) -> mx.array: ...\n"
  },
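A sketch of `create_causal_mask` from this stub; the exact dtype and fill value of the returned mask are implementation details, but conceptually position ``i`` may only attend to positions ``j <= i``:

```python
from mlx_lm.models.base import create_causal_mask

mask = create_causal_mask(4)               # mask for 4 query tokens over 4 keys
shifted = create_causal_mask(2, offset=3)  # 2 new tokens attending over 3 cached + 2 new keys
```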
  {
    "path": ".mlx_typings/mlx_lm/models/bitlinear_layers.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport mlx.nn as nn\n\ndef bitnet_quantize(model, quantization_config: dict): ...\ndef make_bitlinear_kernel():\n    \"\"\"\n    Custom Metal kernel that performs matrix multiplication directly on\n    packed weights and scales the output. This eliminates the need to\n    store unpacked weights in memory.\n    \"\"\"\n\n_bitlinear_kernel = ...\n\nclass BitLinear(nn.Module):\n    \"\"\"\n    BitLinear module with memory-efficient weight handling.\n    \"\"\"\n    def __init__(\n        self, in_features, out_features, bias=..., invert_weight_scales=...\n    ) -> None: ...\n    def execute_matmul_kernel(self, x, packed_weights): ...\n    def __call__(self, x):  # -> array:\n        ...\n"
  },
  {
    "path": ".mlx_typings/mlx_lm/models/cache.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom typing import Any, Dict, List, Optional, Protocol, Literal, Self\n\nimport mlx.nn as nn\nfrom mlx.core import array\nimport mlx.core as mx\n\nclass Cache(Protocol):\n    keys: mx.array\n    values: mx.array\n    offset: int\n    def update_and_fetch(\n        self, keys: mx.array, values: mx.array\n    ) -> tuple[mx.array, mx.array]: ...\n    @property\n    def state(self) -> tuple[mx.array | None, mx.array | None]: ...\n    @state.setter\n    def state(self, v) -> None: ...\n\ndef make_prompt_cache(\n    model: nn.Module, max_kv_size: Optional[int] = ...\n) -> List[Cache | Any]:\n    \"\"\"\n    Construct the model's cache for use in generation.\n\n    This function will defer the cache construction to the model if it has a\n    ``make_cache`` method, otherwise it will make a default KV cache.\n\n    Args:\n        model (nn.Module): The language model.\n        max_kv_size (Optional[int]): If provided and the model does not have a\n            ``make_cache`` method, a ``RotatingKVCache`` is used with a maximum\n            size of ``max_kv_size``\n    \"\"\"\n\ndef save_prompt_cache(\n    file_name: str, cache: List[Cache], metadata: Dict[str, str] = ...\n) -> None:\n    \"\"\"\n    Save a pre-computed prompt cache to a file.\n\n    Args:\n        file_name (str): The ``.safetensors`` file name.\n        cache (List[Any]): The model state.\n        metadata (Dict[str, str]): Optional metadata to save along with model\n            state.\n    \"\"\"\n\ndef load_prompt_cache(file_name: str, return_metadata=...) -> array:\n    \"\"\"\n    Load a prompt cache from a file.\n\n    Args:\n        file_name (str): The ``.safetensors`` file name.\n        return_metadata (bool): Whether or not to return metadata.\n            Default: ``False``.\n\n    Returns:\n        List[Any] or Tuple[List[Any], Dict[str, str]]: The prompt cache and\n            the metadata if requested.\n    \"\"\"\n\ndef can_trim_prompt_cache(cache: List[Cache]) -> bool:\n    \"\"\"\n    Check if model's cache can be trimmed.\n    \"\"\"\n\ndef trim_prompt_cache(cache: List[Cache], num_tokens: int) -> List[Cache]:\n    \"\"\"\n    Trim the model's cache by the given number of tokens.\n\n    This function will trim the cache if possible (in-place) and return the\n    number of tokens that were trimmed.\n\n    Args:\n        cache (List[Any]): The model's cache.\n        num_tokens (int): The number of tokens to trim.\n\n    Returns:\n        (int): The number of tokens that were trimmed.\n    \"\"\"\n\ndef create_attention_mask(\n    N: int, offset: int, return_array: bool, window_size: Optional[int]\n) -> array | Literal[\"causal\"] | None: ...\n\nclass _BaseCache(Cache):\n    keys: mx.array\n    values: mx.array\n    offset: int\n    @property\n    def state(self) -> tuple[mx.array | None, mx.array | None]: ...\n    @state.setter\n    def state(self, v) -> None: ...\n    @property\n    def meta_state(self) -> Literal[\"\"]: ...\n    @meta_state.setter\n    def meta_state(self, v) -> None: ...\n    def trim(self, n: int) -> int: ...\n    def is_trimmable(self) -> Literal[False]: ...\n    @classmethod\n    def from_state(cls, state, meta_state) -> Self: ...\n\nclass ConcatenateKVCache(_BaseCache):\n    \"\"\"ConcatenateKVCache the simplest KV cache implementation.\n\n    Can be used as a mock KV cache or when large blocks are being processed at\n    a time in which case KVCache isn't necessarily faster. 
Consider using the\n    KVCache with a larger step size before using this cache.\n    \"\"\"\n    def __init__(self) -> None: ...\n    def update_and_fetch(self, keys, values):  # -> tuple[Any | array, Any | array]:\n        ...\n    @property\n    def state(self) -> tuple[mx.array | None, mx.array | None]: ...\n    @state.setter\n    def state(self, v):  # -> None:\n        ...\n    def is_trimmable(self):  # -> Literal[True]:\n        ...\n    def trim(self, n: int) -> int: ...\n    def make_mask(self, *args, **kwargs):  # -> array | Literal['causal'] | None:\n        ...\n\nclass QuantizedKVCache(_BaseCache):\n    step = ...\n    def __init__(self, group_size: int = ..., bits: int = ...) -> None: ...\n    def update_and_fetch(self, keys, values):  # -> Any:\n        ...\n    @property\n    def state(self) -> tuple[mx.array | None, mx.array | None]: ...\n    @state.setter\n    def state(self, v):  # -> None:\n        ...\n    @property\n    def meta_state(self):  # -> tuple[str, ...]:\n        ...\n    @meta_state.setter\n    def meta_state(self, v):  # -> None:\n        ...\n    def is_trimmable(self):  # -> Literal[True]:\n        ...\n    def trim(self, n: int) -> int: ...\n    def make_mask(self, *args, **kwargs):  # -> array | Literal['causal'] | None:\n        ...\n\nclass KVCache(_BaseCache):\n    step = ...\n    def __init__(self) -> None: ...\n    def update_and_fetch(self, keys, values):  # -> tuple[array | Any, array | Any]:\n        ...\n    @property\n    def state(\n        self,\n    ) -> tuple[mx.array | None, mx.array | None]: ...\n    @state.setter\n    def state(self, v) -> None: ...\n    def is_trimmable(self):  # -> Literal[True]:\n        ...\n    def trim(self, n: int) -> int: ...\n    def to_quantized(\n        self, group_size: int = ..., bits: int = ...\n    ) -> QuantizedKVCache: ...\n    def make_mask(\n        self, *args: Any, **kwargs: Any\n    ) -> mx.array | Literal[\"causal\"] | None: ...\n\nclass RotatingKVCache(_BaseCache):\n    step = ...\n    keys: mx.array | None\n    values: mx.array | None\n    keep: int\n    max_size: int\n    _idx: int\n    def __init__(self, max_size, keep=...) -> None: ...\n    def _trim(\n        self, trim_size: int, v: mx.array, append: mx.array | None = ...\n    ) -> mx.array: ...\n    def update_and_fetch(\n        self, keys, values\n    ):  # -> tuple[array | Any, array | Any] | tuple[array | Any, array | Any | None]:\n        ...\n    @property\n    def state(\n        self,\n    ) -> tuple[mx.array | None, mx.array | None]: ...\n    @state.setter\n    def state(self, v):  # -> None:\n        ...\n    @property\n    def meta_state(self):  # -> tuple[str, ...]:\n        ...\n    @meta_state.setter\n    def meta_state(self, v):  # -> None:\n        ...\n    def is_trimmable(self):  # -> bool:\n        ...\n    def trim(self, n: int) -> int: ...\n    def to_quantized(\n        self, group_size: int = ..., bits: int = ...\n    ) -> QuantizedKVCache: ...\n    def make_mask(\n        self, N: int, window_size: Optional[int] = ..., return_array: bool = ...\n    ):  # -> array | Literal['causal'] | None:\n        ...\n\nclass ArraysCache(_BaseCache):\n    def __init__(self, size, left_padding: Optional[List[int]] = ...) 
-> None: ...\n    def __setitem__(self, idx, value):  # -> None:\n        ...\n    def __getitem__(self, idx): ...\n    @property\n    def state(self) -> tuple[mx.array | None, mx.array | None]: ...\n    @state.setter\n    def state(self, v):  # -> None:\n        ...\n    def filter(self, batch_indices):  # -> None:\n        \"\"\"\n        In-place filter to keep just the given indices in the cache.\n        \"\"\"\n\n    def extend(self, other):  # -> None:\n        \"\"\"\n        In-place extend this cache with the other cache.\n        \"\"\"\n\n    def make_mask(self, N: int) -> mx.array | None: ...\n\nclass MambaCache(ArraysCache):\n    def __init__(self, left_padding: Optional[List[int]] = ...) -> None: ...\n\nclass ChunkedKVCache(KVCache):\n    def __init__(self, chunk_size) -> None: ...\n    def maybe_trim_front(self):  # -> None:\n        ...\n    def update_and_fetch(self, keys, values):  # -> tuple[array, array]:\n        ...\n    def trim(self, n: int) -> int: ...\n    @property\n    def meta_state(self):  # -> tuple[str, ...]:\n        ...\n    @meta_state.setter\n    def meta_state(self, v):  # -> None:\n        ...\n\nclass CacheList(_BaseCache):\n    def __init__(self, *caches) -> None: ...\n    def __getitem__(self, idx): ...\n    def is_trimmable(self):  # -> bool:\n        ...\n    def trim(self, n: int) -> int: ...\n    @property\n    def state(self) -> list[tuple[mx.array | None, mx.array | None]]: ...\n    @state.setter\n    def state(self, v):  # -> None:\n        ...\n    def filter(self, batch_indices):  # -> None:\n        \"\"\"\n        In-place filter to keep just the given indices in the cache.\n        \"\"\"\n\n    def extend(self, other):  # -> None:\n        \"\"\"\n        In-place extend this cache with the other cache.\n        \"\"\"\n\nclass BatchKVCache(_BaseCache):\n    step = ...\n    def __init__(self, left_padding: List[int]) -> None:\n        \"\"\"\n        The BatchKV cache expects inputs to be left-padded.\n\n        E.g. 
the following prompts:\n\n            [1, 3, 5]\n            [7]\n            [2, 6, 8, 9]\n\n        Should be padded like so:\n\n            [0, 1, 3, 5]\n            [0, 0, 0, 7]\n            [2, 6, 8, 9]\n\n        And ``left_padding`` specifies the amount of padding for each.\n        In this case, ``left_padding = [1, 3, 0]``.\n        \"\"\"\n\n    def update_and_fetch(self, keys, values):  # -> tuple[array | Any, array | Any]:\n        ...\n    @property\n    def state(\n        self,\n    ):  # -> tuple[Any | array | None, Any | array | None, array | Any, array | Any]:\n        ...\n    @state.setter\n    def state(self, v):  # -> None:\n        ...\n    def is_trimmable(self):  # -> Literal[True]:\n        ...\n    def trim(self, n):  # -> int | float:\n        ...\n    def make_mask(self, N: int, return_array: bool = ..., **kwargs):  # -> array:\n        ...\n    def filter(self, batch_indices):  # -> None:\n        \"\"\"\n        In-place filter to keep just the given indices in the cache.\n        \"\"\"\n\n    def extend(self, other):  # -> None:\n        \"\"\"\n        In-place extend this cache with the other cache.\n        \"\"\"\n\nclass BatchRotatingKVCache(_BaseCache):\n    step = ...\n    def __init__(self, max_size, left_padding: List[int]) -> None: ...\n    def update_and_fetch(\n        self, keys, values\n    ):  # -> tuple[array | Any, array | Any] | tuple[array | Any, array | Any | None]:\n        ...\n    @property\n    def state(\n        self,\n    ):  # -> tuple[Any | array | None, Any | array | None, array | Any, array | Any]:\n        ...\n    @state.setter\n    def state(self, v):  # -> None:\n        ...\n    @property\n    def meta_state(self):  # -> tuple[str, ...]:\n        ...\n    @meta_state.setter\n    def meta_state(self, v):  # -> None:\n        ...\n    def is_trimmable(self):  # -> bool:\n        ...\n    def trim(self, n):  # -> int:\n        ...\n    def to_quantized(\n        self, group_size: int = ..., bits: int = ...\n    ) -> QuantizedKVCache: ...\n    def make_mask(\n        self, N: int, window_size: Optional[int] = ..., return_array: bool = ...\n    ):  # -> array:\n        ...\n    def filter(self, batch_indices):  # -> None:\n        \"\"\"\n        In-place filter to keep just the given indices in the cache.\n        \"\"\"\n\n    def extend(self, other):  # -> None:\n        \"\"\"\n        In-place extend this cache with the other cache.\n        \"\"\"\n"
  },
  {
    "path": ".mlx_typings/mlx_lm/models/deepseek_v3.pyi",
    "content": "\"\"\"Type stubs for mlx_lm.models.deepseek_v3\"\"\"\n\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, Optional\n\nimport mlx.core as mx\nimport mlx.nn as nn\nfrom mlx_lm.models.mla import MultiLinear\n\nfrom .base import BaseModelArgs\nfrom .switch_layers import SwitchGLU\n\n@dataclass\nclass ModelArgs(BaseModelArgs):\n    model_type: str\n    vocab_size: int\n    hidden_size: int\n    intermediate_size: int\n    moe_intermediate_size: int\n    num_hidden_layers: int\n    num_attention_heads: int\n    num_key_value_heads: int\n    n_shared_experts: Optional[int]\n    n_routed_experts: Optional[int]\n    routed_scaling_factor: float\n    kv_lora_rank: int\n    q_lora_rank: Optional[int]\n    qk_rope_head_dim: int\n    v_head_dim: int\n    qk_nope_head_dim: int\n    topk_method: str\n    scoring_func: str\n    norm_topk_prob: bool\n    n_group: int\n    topk_group: int\n    num_experts_per_tok: int\n    moe_layer_freq: int\n    first_k_dense_replace: int\n    max_position_embeddings: int\n    rms_norm_eps: float\n    rope_theta: float\n    rope_scaling: Optional[Dict[str, Any]]\n    attention_bias: bool\n\nclass DeepseekV3Attention(nn.Module):\n    config: ModelArgs\n    hidden_size: int\n    num_heads: int\n    max_position_embeddings: int\n    rope_theta: float\n    q_lora_rank: Optional[int]\n    qk_rope_head_dim: int\n    kv_lora_rank: int\n    v_head_dim: int\n    qk_nope_head_dim: int\n    q_head_dim: int\n    scale: float\n    q_proj: nn.Linear\n    q_a_proj: nn.Linear\n    q_a_layernorm: nn.RMSNorm\n    q_b_proj: nn.Linear\n    kv_a_proj_with_mqa: nn.Linear\n    kv_a_layernorm: nn.RMSNorm\n    # kv_b_proj: nn.Linear\n    embed_q: MultiLinear\n    unembed_out: MultiLinear\n\n    o_proj: nn.Linear\n    rope: Any\n\n    def __init__(self, config: ModelArgs) -> None: ...\n    def __call__(\n        self,\n        x: mx.array,\n        mask: Optional[mx.array] = None,\n        cache: Optional[Any] = None,\n    ) -> mx.array: ...\n\nclass DeepseekV3MLP(nn.Module):\n    config: ModelArgs\n    hidden_size: int\n    intermediate_size: int\n    gate_proj: nn.Linear\n    up_proj: nn.Linear\n    down_proj: nn.Linear\n\n    def __init__(\n        self,\n        config: ModelArgs,\n        hidden_size: Optional[int] = None,\n        intermediate_size: Optional[int] = None,\n    ) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n\nclass MoEGate(nn.Module):\n    config: ModelArgs\n    top_k: int\n    norm_topk_prob: bool\n    n_routed_experts: Optional[int]\n    routed_scaling_factor: float\n    n_group: int\n    topk_group: int\n    weight: mx.array\n    e_score_correction_bias: mx.array\n\n    def __init__(self, config: ModelArgs) -> None: ...\n    def __call__(self, x: mx.array) -> tuple[mx.array, mx.array]: ...\n\nclass DeepseekV3MoE(nn.Module):\n    config: ModelArgs\n    num_experts_per_tok: int\n    switch_mlp: SwitchGLU\n    gate: MoEGate\n    shared_experts: DeepseekV3MLP\n    sharding_group: Optional[mx.distributed.Group]\n\n    def __init__(self, config: ModelArgs) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n\nclass DeepseekV3DecoderLayer(nn.Module):\n    self_attn: DeepseekV3Attention\n    mlp: DeepseekV3MLP | DeepseekV3MoE\n    input_layernorm: nn.RMSNorm\n    post_attention_layernorm: nn.RMSNorm\n\n    def __init__(self, config: ModelArgs, layer_idx: int) -> None: ...\n    def __call__(\n        self,\n        x: mx.array,\n        mask: Optional[mx.array] = None,\n        cache: Optional[Any] = None,\n    ) 
-> mx.array: ...\n\nclass DeepseekV3Model(nn.Module):\n    vocab_size: int\n    embed_tokens: nn.Embedding\n    layers: list[DeepseekV3DecoderLayer]\n    norm: nn.RMSNorm\n\n    def __init__(self, config: ModelArgs) -> None: ...\n    def __call__(\n        self,\n        x: mx.array,\n        cache: Optional[Any] = None,\n    ) -> mx.array: ...\n\nclass Model(nn.Module):\n    model_type: str\n    model: DeepseekV3Model\n    lm_head: nn.Linear\n\n    def __init__(self, config: ModelArgs) -> None: ...\n    def __call__(\n        self,\n        inputs: mx.array,\n        cache: Optional[Any] = None,\n    ) -> mx.array: ...\n    def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...\n    @property\n    def layers(self) -> list[DeepseekV3DecoderLayer]: ...\n"
  },
  {
    "path": ".mlx_typings/mlx_lm/models/glm4_moe.pyi",
    "content": "from dataclasses import dataclass\nfrom typing import Any, Dict, Optional\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nfrom .base import BaseModelArgs\nfrom .switch_layers import SwitchGLU\n\n@dataclass\nclass ModelArgs(BaseModelArgs):\n    model_type: str\n    vocab_size: int\n    hidden_size: int\n    intermediate_size: int\n    max_position_embeddings: int\n    moe_intermediate_size: int\n    norm_topk_prob: bool\n    num_attention_heads: int\n    n_group: int\n    head_dim: int\n    topk_group: int\n    n_shared_experts: int\n    n_routed_experts: int\n    routed_scaling_factor: float\n    num_experts_per_tok: int\n    first_k_dense_replace: int\n    num_hidden_layers: int\n    num_key_value_heads: int\n    rms_norm_eps: float\n    rope_theta: float\n    rope_scaling: Optional[Dict[str, Any]]\n    use_qk_norm: bool\n    tie_word_embeddings: bool\n    attention_bias: bool\n    partial_rotary_factor: float\n    scoring_func: str\n    topk_method: str\n\nclass Attention(nn.Module):\n    n_heads: int\n    n_kv_heads: int\n    scale: float\n    q_proj: nn.Linear\n    k_proj: nn.Linear\n    v_proj: nn.Linear\n    o_proj: nn.Linear\n    use_qk_norm: bool\n    q_norm: nn.RMSNorm\n    k_norm: nn.RMSNorm\n    rope: nn.RoPE\n\n    def __init__(self, args: ModelArgs) -> None: ...\n    def __call__(\n        self,\n        x: mx.array,\n        mask: Optional[mx.array] = None,\n        cache: Optional[Any] = None,\n    ) -> mx.array: ...\n\nclass MLP(nn.Module):\n    config: ModelArgs\n    hidden_size: int\n    intermediate_size: int\n    gate_proj: nn.Linear\n    up_proj: nn.Linear\n    down_proj: nn.Linear\n\n    def __init__(\n        self,\n        config: ModelArgs,\n        hidden_size: Optional[int] = None,\n        intermediate_size: Optional[int] = None,\n    ) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n\nclass MoEGate(nn.Module):\n    config: ModelArgs\n    top_k: int\n    norm_topk_prob: bool\n    n_routed_experts: int\n    routed_scaling_factor: float\n    n_group: int\n    topk_group: int\n    weight: mx.array\n    e_score_correction_bias: mx.array\n\n    def __init__(self, config: ModelArgs) -> None: ...\n    def __call__(self, x: mx.array) -> tuple[mx.array, mx.array]: ...\n\nclass MoE(nn.Module):\n    config: ModelArgs\n    num_experts_per_tok: int\n    switch_mlp: SwitchGLU\n    gate: MoEGate\n    shared_experts: MLP\n    sharding_group: Optional[mx.distributed.Group]\n\n    def __init__(self, config: ModelArgs) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n\nclass DecoderLayer(nn.Module):\n    self_attn: Attention\n    mlp: MLP | MoE\n    input_layernorm: nn.RMSNorm\n    post_attention_layernorm: nn.RMSNorm\n\n    def __init__(self, config: ModelArgs, layer_idx: int) -> None: ...\n    def __call__(\n        self,\n        x: mx.array,\n        mask: Optional[mx.array] = None,\n        cache: Optional[Any] = None,\n    ) -> mx.array: ...\n\nclass LanguageModel(nn.Module):\n    vocab_size: int\n    embed_tokens: nn.Embedding\n    layers: list[DecoderLayer]\n    norm: nn.RMSNorm\n    pipeline_rank: int\n    pipeline_size: int\n    start_idx: int\n    end_idx: Optional[int]\n    num_layers: int\n\n    def __init__(self, config: ModelArgs) -> None: ...\n    def __call__(\n        self,\n        x: mx.array,\n        cache: Optional[Any] = None,\n    ) -> mx.array: ...\n    @property\n    def pipeline_layers(self) -> list[DecoderLayer]: ...\n\nclass Model(nn.Module):\n    args: ModelArgs\n    model_type: str\n    model: 
LanguageModel\n    lm_head: nn.Linear\n\n    def __init__(self, config: ModelArgs) -> None: ...\n    def __call__(\n        self,\n        inputs: mx.array,\n        cache: Optional[Any] = None,\n    ) -> mx.array: ...\n    def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...\n    def shard(self, group: Optional[mx.distributed.Group] = None) -> None: ...\n    @property\n    def layers(self) -> list[DecoderLayer]: ...\n    @property\n    def cast_predicate(self) -> Any: ...\n"
  },
  {
    "path": ".mlx_typings/mlx_lm/models/glm_moe_dsa.pyi",
    "content": "\"\"\"Type stubs for mlx_lm.models.glm_moe_dsa\"\"\"\n\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, Optional\n\nfrom .base import BaseModelArgs\nfrom .deepseek_v32 import Model as DSV32Model\n\n@dataclass\nclass ModelArgs(BaseModelArgs):\n    model_type: str\n    vocab_size: int\n    hidden_size: int\n    index_head_dim: int\n    index_n_heads: int\n    index_topk: int\n    intermediate_size: int\n    moe_intermediate_size: int\n    num_hidden_layers: int\n    num_attention_heads: int\n    num_key_value_heads: int\n    n_shared_experts: Optional[int]\n    n_routed_experts: Optional[int]\n    routed_scaling_factor: float\n    kv_lora_rank: int\n    q_lora_rank: int\n    qk_rope_head_dim: int\n    v_head_dim: int\n    qk_nope_head_dim: int\n    topk_method: str\n    scoring_func: str\n    norm_topk_prob: bool\n    n_group: int\n    topk_group: int\n    num_experts_per_tok: int\n    moe_layer_freq: int\n    first_k_dense_replace: int\n    max_position_embeddings: int\n    rms_norm_eps: float\n    rope_parameters: Dict[str, Any]\n    attention_bias: bool\n    rope_scaling: Dict[str, Any] | None\n    rope_theta: float | None\n\nclass Model(DSV32Model):\n    def __init__(self, config: ModelArgs) -> None: ...\n"
  },
  {
    "path": ".mlx_typings/mlx_lm/models/nemotron_h.pyi",
    "content": "from dataclasses import dataclass\nfrom typing import Any, List, Optional, Tuple\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nfrom .cache import ArraysCache, KVCache\nfrom .switch_layers import SwitchMLP\n\n@dataclass\nclass ModelArgs:\n    model_type: str\n    vocab_size: int\n    hidden_size: int\n    intermediate_size: int\n    num_hidden_layers: int\n    max_position_embeddings: int\n    num_attention_heads: int\n    num_key_value_heads: int\n    attention_bias: bool\n    mamba_num_heads: int\n    mamba_head_dim: int\n    mamba_proj_bias: bool\n    ssm_state_size: int\n    conv_kernel: int\n    n_groups: int\n    mlp_bias: bool\n    layer_norm_epsilon: float\n    use_bias: bool\n    use_conv_bias: bool\n    hybrid_override_pattern: List[str]\n    head_dim: Optional[int]\n    moe_intermediate_size: Optional[int]\n    moe_shared_expert_intermediate_size: Optional[int]\n    n_group: Optional[int]\n    n_routed_experts: Optional[int]\n    n_shared_experts: Optional[int]\n    topk_group: Optional[int]\n    num_experts_per_tok: Optional[int]\n    norm_topk_prob: Optional[bool]\n    routed_scaling_factor: Optional[float]\n    time_step_limit: Optional[Tuple[float, float]]\n    time_step_min: Optional[float]\n    time_step_max: Optional[float]\n\n    @classmethod\n    def from_dict(cls, params: dict[str, Any]) -> ModelArgs: ...\n    def __post_init__(self) -> None: ...\n\nclass NemotronHMamba2Mixer(nn.Module):\n    num_heads: int\n    hidden_size: int\n    ssm_state_size: int\n    conv_kernel_size: int\n    intermediate_size: int\n    n_groups: int\n    head_dim: int\n    conv_dim: int\n    conv1d: nn.Conv1d\n    in_proj: nn.Linear\n    dt_bias: mx.array\n    A_log: mx.array\n    D: mx.array\n    norm: nn.RMSNorm\n    heads_per_group: int\n    out_proj: nn.Linear\n\n    def __init__(self, args: ModelArgs) -> None: ...\n    def __call__(\n        self,\n        hidden_states: mx.array,\n        mask: Optional[mx.array],\n        cache: Optional[ArraysCache] = None,\n    ) -> mx.array: ...\n\nclass NemotronHAttention(nn.Module):\n    hidden_size: int\n    num_heads: int\n    head_dim: int\n    num_key_value_heads: int\n    scale: float\n    q_proj: nn.Linear\n    k_proj: nn.Linear\n    v_proj: nn.Linear\n    o_proj: nn.Linear\n\n    def __init__(self, args: ModelArgs) -> None: ...\n    def __call__(\n        self,\n        x: mx.array,\n        mask: Optional[mx.array] = None,\n        cache: Optional[KVCache] = None,\n    ) -> mx.array: ...\n\nclass NemotronHMLP(nn.Module):\n    up_proj: nn.Linear\n    down_proj: nn.Linear\n\n    def __init__(\n        self, args: ModelArgs, intermediate_size: Optional[int] = None\n    ) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n\nclass NemotronHMoE(nn.Module):\n    num_experts_per_tok: int\n    switch_mlp: SwitchMLP\n    shared_experts: NemotronHMLP\n\n    def __init__(self, config: ModelArgs) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n\nclass NemotronHBlock(nn.Module):\n    block_type: str\n    norm: nn.RMSNorm\n    mixer: NemotronHMamba2Mixer | NemotronHAttention | NemotronHMLP | NemotronHMoE\n\n    def __init__(self, args: ModelArgs, block_type: str) -> None: ...\n    def __call__(\n        self,\n        x: mx.array,\n        mask: Optional[mx.array] = None,\n        cache: Optional[Any] = None,\n    ) -> mx.array: ...\n\nclass NemotronHModel(nn.Module):\n    embeddings: nn.Embedding\n    layers: list[NemotronHBlock]\n    norm_f: nn.RMSNorm\n    fa_idx: int\n    ssm_idx: int\n\n    def 
__init__(self, args: ModelArgs) -> None: ...\n    def __call__(\n        self,\n        inputs: mx.array,\n        cache: Optional[Any] = None,\n    ) -> mx.array: ...\n\nclass Model(nn.Module):\n    args: ModelArgs\n    backbone: NemotronHModel\n    lm_head: nn.Linear\n    model_type: str\n\n    def __init__(self, args: ModelArgs) -> None: ...\n    def __call__(\n        self,\n        inputs: mx.array,\n        cache: Optional[Any] = None,\n    ) -> mx.array: ...\n    @property\n    def layers(self) -> list[NemotronHBlock]: ...\n    def make_cache(self) -> list[ArraysCache | KVCache]: ...\n    def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...\n"
  },
  {
    "path": ".mlx_typings/mlx_lm/models/qwen3_5.pyi",
    "content": "from dataclasses import dataclass\nfrom typing import Any, Optional\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nfrom .cache import ArraysCache, KVCache\nfrom .qwen3_next import (\n    Qwen3NextAttention as Attention,\n    Qwen3NextMLP as MLP,\n    Qwen3NextRMSNormGated as RMSNormGated,\n    Qwen3NextSparseMoeBlock,\n)\n\nSparseMoeBlock = Qwen3NextSparseMoeBlock\nfrom .switch_layers import SwitchGLU\n\n@dataclass\nclass TextModelArgs:\n    model_type: str\n    hidden_size: int\n    intermediate_size: int\n    num_hidden_layers: int\n    num_attention_heads: int\n    rms_norm_eps: float\n    vocab_size: int\n    num_key_value_heads: int\n    max_position_embeddings: int\n    linear_num_value_heads: int\n    linear_num_key_heads: int\n    linear_key_head_dim: int\n    linear_value_head_dim: int\n    linear_conv_kernel_dim: int\n    tie_word_embeddings: bool\n    attention_bias: bool\n    head_dim: Optional[int]\n    full_attention_interval: int\n    num_experts: int\n    num_experts_per_tok: int\n    decoder_sparse_step: int\n    shared_expert_intermediate_size: int\n    moe_intermediate_size: int\n    norm_topk_prob: bool\n    rope_parameters: Optional[dict[str, Any]]\n    partial_rotary_factor: float\n    rope_theta: float\n    rope_scaling: Optional[dict[str, Any]]\n\n    @classmethod\n    def from_dict(cls, params: dict[str, Any]) -> TextModelArgs: ...\n    def __post_init__(self) -> None: ...\n\nclass GatedDeltaNet(nn.Module):\n    hidden_size: int\n    num_v_heads: int\n    num_k_heads: int\n    head_k_dim: int\n    head_v_dim: int\n    key_dim: int\n    value_dim: int\n    conv_kernel_size: int\n    conv_dim: int\n    conv1d: nn.Conv1d\n    in_proj_qkv: nn.Linear\n    in_proj_z: nn.Linear\n    in_proj_b: nn.Linear\n    in_proj_a: nn.Linear\n    dt_bias: mx.array\n    A_log: mx.array\n    norm: RMSNormGated\n    out_proj: nn.Linear\n\n    def __init__(self, config: TextModelArgs) -> None: ...\n    def __call__(\n        self,\n        inputs: mx.array,\n        mask: Optional[mx.array] = None,\n        cache: Optional[Any] = None,\n    ) -> mx.array: ...\n\nclass DecoderLayer(nn.Module):\n    is_linear: bool\n    linear_attn: GatedDeltaNet\n    self_attn: Attention\n    input_layernorm: nn.RMSNorm\n    post_attention_layernorm: nn.RMSNorm\n    mlp: MLP | SparseMoeBlock\n\n    def __init__(self, args: TextModelArgs, layer_idx: int) -> None: ...\n    def __call__(\n        self,\n        x: mx.array,\n        mask: Optional[mx.array] = None,\n        cache: Optional[Any] = None,\n    ) -> mx.array: ...\n\nclass Qwen3_5TextModel(nn.Module):\n    embed_tokens: nn.Embedding\n    layers: list[DecoderLayer]\n    norm: nn.RMSNorm\n    ssm_idx: int\n    fa_idx: int\n\n    def __init__(self, args: TextModelArgs) -> None: ...\n    def __call__(\n        self,\n        inputs: mx.array,\n        cache: Optional[Any] = None,\n        input_embeddings: Optional[mx.array] = None,\n    ) -> mx.array: ...\n\nclass TextModel(nn.Module):\n    args: TextModelArgs\n    model_type: str\n    model: Qwen3_5TextModel\n    lm_head: nn.Linear\n\n    def __init__(self, args: TextModelArgs) -> None: ...\n    def __call__(\n        self,\n        inputs: mx.array,\n        cache: Optional[Any] = None,\n        input_embeddings: Optional[mx.array] = None,\n    ) -> mx.array: ...\n    @property\n    def layers(self) -> list[DecoderLayer]: ...\n    def make_cache(self) -> list[ArraysCache | KVCache]: ...\n    def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...\n\n@dataclass\nclass 
ModelArgs:\n    model_type: str\n    text_config: dict[str, Any]\n\n    @classmethod\n    def from_dict(cls, params: dict[str, Any]) -> ModelArgs: ...\n\nclass Model(nn.Module):\n    args: ModelArgs\n    model_type: str\n    language_model: TextModel\n\n    def __init__(self, args: ModelArgs) -> None: ...\n    def __call__(\n        self,\n        inputs: mx.array,\n        cache: Optional[Any] = None,\n        input_embeddings: Optional[mx.array] = None,\n    ) -> mx.array: ...\n    def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...\n    @property\n    def layers(self) -> list[DecoderLayer]: ...\n    def make_cache(self) -> list[ArraysCache | KVCache]: ...\n"
  },
  {
    "path": ".mlx_typings/mlx_lm/models/qwen3_5_moe.pyi",
    "content": "from dataclasses import dataclass\nfrom typing import Any, Optional\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nfrom .cache import ArraysCache, KVCache\nfrom .qwen3_5 import DecoderLayer, Model as Qwen3_5Model, TextModel\n\n@dataclass\nclass ModelArgs:\n    model_type: str\n    text_config: dict[str, Any]\n\n    @classmethod\n    def from_dict(cls, params: dict[str, Any]) -> ModelArgs: ...\n\nclass Model(Qwen3_5Model):\n    def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...\n"
  },
  {
    "path": ".mlx_typings/mlx_lm/models/qwen3_next.pyi",
    "content": "\"\"\"Type stubs for mlx_lm.models.qwen3_next\"\"\"\n\nfrom typing import Any, Optional\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nfrom .cache import ArraysCache, KVCache\nfrom .switch_layers import SwitchGLU\n\nclass Qwen3NextRMSNormGated(nn.Module):\n    eps: float\n    weight: mx.array\n\n    def __init__(self, hidden_size: int, eps: float = ...) -> None: ...\n    def __call__(\n        self, hidden_states: mx.array, gate: mx.array | None = None\n    ) -> mx.array: ...\n\nclass Qwen3NextMLP(nn.Module):\n    gate_proj: nn.Linear\n    down_proj: nn.Linear\n    up_proj: nn.Linear\n\n    def __init__(self, dim: int, hidden_dim: int) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n\nclass Qwen3NextGatedDeltaNet(nn.Module):\n    hidden_size: int\n    num_v_heads: int\n    num_k_heads: int\n    head_k_dim: int\n    head_v_dim: int\n    key_dim: int\n    value_dim: int\n    conv_kernel_size: int\n    conv_dim: int\n    conv1d: nn.Conv1d\n    in_proj_qkvz: nn.Linear\n    in_proj_ba: nn.Linear\n    dt_bias: mx.array\n    A_log: mx.array\n    out_proj: nn.Linear\n\n    def __init__(self, config: Any) -> None: ...\n    def __call__(\n        self,\n        inputs: mx.array,\n        mask: Optional[mx.array] = None,\n        cache: Optional[Any] = None,\n    ) -> mx.array: ...\n\nclass Qwen3NextAttention(nn.Module):\n    num_attention_heads: int\n    num_key_value_heads: int\n    head_dim: int\n    scale: float\n    q_proj: nn.Linear\n    k_proj: nn.Linear\n    v_proj: nn.Linear\n    o_proj: nn.Linear\n\n    def __init__(self, args: Any) -> None: ...\n    def __call__(\n        self,\n        x: mx.array,\n        mask: Optional[mx.array] = None,\n        cache: Optional[Any] = None,\n    ) -> mx.array: ...\n\nclass Qwen3NextSparseMoeBlock(nn.Module):\n    norm_topk_prob: bool\n    num_experts: int\n    top_k: int\n    gate: nn.Linear\n    switch_mlp: SwitchGLU\n    shared_expert: Qwen3NextMLP\n    shared_expert_gate: nn.Linear\n\n    def __init__(self, args: Any) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n\nclass Qwen3NextDecoderLayer(nn.Module):\n    is_linear: bool\n    linear_attn: Qwen3NextGatedDeltaNet\n    self_attn: Qwen3NextAttention\n    input_layernorm: nn.RMSNorm\n    post_attention_layernorm: nn.RMSNorm\n    mlp: Qwen3NextMLP | Qwen3NextSparseMoeBlock\n\n    def __init__(self, args: Any, layer_idx: int) -> None: ...\n    def __call__(\n        self,\n        x: mx.array,\n        mask: Optional[mx.array] = None,\n        cache: Optional[Any] = None,\n    ) -> mx.array: ...\n\nclass Qwen3NextModel(nn.Module):\n    embed_tokens: nn.Embedding\n    layers: list[Qwen3NextDecoderLayer]\n    norm: nn.RMSNorm\n    ssm_idx: int\n    fa_idx: int\n\n    def __init__(self, args: Any) -> None: ...\n    def __call__(\n        self,\n        inputs: mx.array,\n        cache: Optional[Any] = None,\n    ) -> mx.array: ...\n\nclass Model(nn.Module):\n    model_type: str\n    model: Qwen3NextModel\n    lm_head: nn.Linear\n\n    def __init__(self, args: Any) -> None: ...\n    def __call__(\n        self,\n        inputs: mx.array,\n        cache: Optional[Any] = None,\n    ) -> mx.array: ...\n    def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...\n    @property\n    def layers(self) -> list[Qwen3NextDecoderLayer]: ...\n    def make_cache(self) -> list[ArraysCache | KVCache]: ...\n"
  },
  {
    "path": ".mlx_typings/mlx_lm/models/step3p5.pyi",
    "content": "from dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nfrom .base import BaseModelArgs\nfrom .switch_layers import SwitchGLU\n\n@dataclass\nclass ModelArgs(BaseModelArgs):\n    model_type: str\n    hidden_size: int\n    num_hidden_layers: int\n    vocab_size: int\n    num_attention_heads: int\n    num_attention_groups: int\n    head_dim: int\n    intermediate_size: int\n    rms_norm_eps: float\n    rope_theta: float\n    rope_scaling: Optional[Dict[str, Any]]\n    max_position_embeddings: int\n    sliding_window: int\n    layer_types: Optional[List[str]]\n    yarn_only_types: Optional[List[str]]\n    partial_rotary_factors: Optional[List[float]]\n    attention_other_setting: Optional[Dict[str, Any]]\n    use_head_wise_attn_gate: bool\n    moe_num_experts: int\n    moe_top_k: int\n    moe_intermediate_size: int\n    share_expert_dim: int\n    moe_layers_enum: Optional[str]\n    moe_router_scaling_factor: float\n    norm_expert_weight: bool\n    swiglu_limits: Optional[List[float]]\n    swiglu_limits_shared: Optional[List[float]]\n    tie_word_embeddings: bool\n\nclass Step3p5MLP(nn.Module):\n    hidden_size: int\n    intermediate_size: int\n    gate_proj: nn.Linear\n    up_proj: nn.Linear\n    down_proj: nn.Linear\n    limit: Optional[float]\n\n    def __init__(\n        self, args: ModelArgs, intermediate_size: int, swiglu_limit: float = 0\n    ) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n\nclass Step3p5MoEGate(nn.Module):\n    top_k: int\n    n_routed_experts: int\n    routed_scaling_factor: float\n    norm_topk_prob: bool\n    gate: nn.Linear\n    router_bias: mx.array\n\n    def __init__(self, args: ModelArgs) -> None: ...\n    def __call__(self, x: mx.array) -> tuple[mx.array, mx.array]: ...\n\nclass Step3p5MoE(nn.Module):\n    gate: Step3p5MoEGate\n    switch_mlp: SwitchGLU\n    share_expert: Step3p5MLP\n    sharding_group: Optional[mx.distributed.Group]\n\n    def __init__(self, args: ModelArgs, layer_idx: int) -> None: ...\n    def __call__(self, x: mx.array) -> mx.array: ...\n\nclass Step3p5Attention(nn.Module):\n    is_sliding: bool\n    num_heads: int\n    num_kv_heads: int\n    head_dim: int\n    scale: float\n    q_proj: nn.Linear\n    k_proj: nn.Linear\n    v_proj: nn.Linear\n    o_proj: nn.Linear\n    q_norm: nn.Module\n    k_norm: nn.Module\n    use_head_wise_attn_gate: bool\n    g_proj: nn.Linear\n    rope: nn.Module\n\n    def __init__(self, args: ModelArgs, layer_idx: int) -> None: ...\n    def __call__(\n        self,\n        x: mx.array,\n        mask: Optional[mx.array] = None,\n        cache: Optional[Any] = None,\n    ) -> mx.array: ...\n\nclass Step3p5DecoderLayer(nn.Module):\n    self_attn: Step3p5Attention\n    is_sliding: bool\n    is_moe_layer: bool\n    mlp: Step3p5MLP | Step3p5MoE\n    input_layernorm: nn.Module\n    post_attention_layernorm: nn.Module\n\n    def __init__(self, args: ModelArgs, layer_idx: int) -> None: ...\n    def __call__(\n        self,\n        x: mx.array,\n        mask: Optional[mx.array] = None,\n        cache: Optional[Any] = None,\n    ) -> mx.array: ...\n\nclass Step3p5Model(nn.Module):\n    args: ModelArgs\n    vocab_size: int\n    num_layers: int\n    embed_tokens: nn.Embedding\n    layers: list[Step3p5DecoderLayer]\n    norm: nn.Module\n    _swa_idx: Optional[int]\n    _full_idx: Optional[int]\n\n    def __init__(self, args: ModelArgs) -> None: ...\n    def __call__(\n        self,\n        x: mx.array,\n        
cache: Optional[List[Any]] = None,\n    ) -> mx.array: ...\n\nclass Model(nn.Module):\n    args: ModelArgs\n    model_type: str\n    model: Step3p5Model\n    lm_head: nn.Linear\n\n    def __init__(self, args: ModelArgs) -> None: ...\n    def __call__(\n        self,\n        inputs: mx.array,\n        cache: Optional[List[Any]] = None,\n    ) -> mx.array: ...\n    def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...\n    def shard(self, group: Optional[mx.distributed.Group] = None) -> None: ...\n    @property\n    def layers(self) -> list[Step3p5DecoderLayer]: ...\n    def make_cache(self) -> list[Any]: ...\n    @property\n    def cast_predicate(self) -> Any: ...\n    @property\n    def quant_predicate(self) -> Any: ...\n"
  },
  {
    "path": ".mlx_typings/mlx_lm/models/switch_layers.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom functools import partial\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nclass QuantizedSwitchLinear(nn.Module):\n    def __init__(\n        self,\n        input_dims: int,\n        output_dims: int,\n        num_experts: int,\n        bias: bool = ...,\n        group_size: int = ...,\n        bits: int = ...,\n        mode: str = ...,\n    ) -> None: ...\n    @property\n    def input_dims(self):  # -> int:\n        ...\n    @property\n    def output_dims(self):  # -> int:\n        ...\n    @property\n    def num_experts(self):  # -> int:\n        ...\n    def __call__(self, x, indices, sorted_indices=...):  # -> array:\n        ...\n\nclass SwitchLinear(nn.Module):\n    def __init__(\n        self, input_dims: int, output_dims: int, num_experts: int, bias: bool = ...\n    ) -> None: ...\n    @property\n    def input_dims(self):  # -> int:\n        ...\n    @property\n    def output_dims(self):  # -> int:\n        ...\n    @property\n    def num_experts(self):  # -> int:\n        ...\n    def __call__(self, x, indices, sorted_indices=...): ...\n    def to_quantized(\n        self, group_size: int = ..., bits: int = ..., mode: str = ...\n    ):  # -> QuantizedSwitchLinear:\n        ...\n\n@partial(mx.compile, shapeless=True)\ndef swiglu(x, gate): ...\n\nclass SwiGLU(nn.Module):\n    def __init__(self) -> None: ...\n    def __call__(self, x, gate): ...\n\nclass SwitchGLU(nn.Module):\n    gate_proj: SwitchLinear\n    up_proj: SwitchLinear\n    down_proj: SwitchLinear\n    activation: SwiGLU\n\n    def __init__(\n        self,\n        input_dims: int,\n        hidden_dims: int,\n        num_experts: int,\n        activation=...,\n        bias: bool = ...,\n    ) -> None: ...\n    def __call__(self, x, indices) -> mx.array: ...\n\nclass SwitchMLP(nn.Module):\n    fc1: SwitchLinear\n    fc2: SwitchLinear\n\n    def __init__(\n        self,\n        input_dims: int,\n        hidden_dims: int,\n        num_experts: int,\n        activation=...,\n        bias: bool = ...,\n    ) -> None: ...\n    def __call__(self, x, indices) -> mx.array: ...\n"
  },
  {
    "path": ".mlx_typings/mlx_lm/sample_utils.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom functools import partial\nfrom typing import Callable, Dict, List, Optional\n\nimport mlx.core as mx\n\ndef make_sampler(\n    temp: float = ...,\n    top_p: float = ...,\n    min_p: float = ...,\n    min_tokens_to_keep: int = ...,\n    top_k: int = ...,\n    xtc_probability: float = ...,\n    xtc_threshold: float = ...,\n    xtc_special_tokens: List[int] = ...,\n) -> Callable[[mx.array], mx.array]:\n    \"\"\"\n    Make a sampler function for use with ``generate_step``.\n\n    Args:\n        temp (float): The temperature for sampling, if 0 the argmax is used.\n          Default: ``0``.\n        top_p (float, optional): Nulceus sampling, higher means model considers\n          more less likely words.\n        min_p (float, optional): The minimum value (scaled by the top token's\n          probability) that a token probability must have to be considered.\n        min_tokens_to_keep (int, optional): Minimum number of tokens that cannot\n          be filtered by min_p sampling.\n        top_k (int, optional): The top k tokens ranked by probability to constrain\n          the sampling to.\n        xtc_probability (float, optional): The probability of applying XTC\n            sampling.\n        xtc_threshold (float, optional): The threshold the probs need to reach\n            for being sampled.\n        xtc_special_tokens (list(int), optional): List of special tokens IDs to\n            be excluded from XTC sampling.\n\n\n    Returns:\n        Callable[mx.array, mx.array]:\n            A sampler which takes log-probabilities and returns tokens.\n    \"\"\"\n\ndef make_logits_processors(\n    logit_bias: Optional[Dict[int, float]] = ...,\n    repetition_penalty: Optional[float] = ...,\n    repetition_context_size: Optional[int] = ...,\n) -> list[Callable[[mx.array, mx.array], mx.array]]:\n    \"\"\"\n    Make logits processors for use with ``generate_step``.\n\n    Args:\n        repetition_penalty (float, optional): The penalty factor for repeating\n          tokens.\n        repetition_context_size (int, optional): The number of tokens to\n          consider for repetition penalty. Default: ``20``.\n        logit_bias (dictionary, optional): Additive logit bias.\n\n    Returns:\n        List[Callable[[mx.array, mx.array], mx.array]]:\n            A list of logits processors. Each processor in the list is a\n            callable which takes an array of tokens and an array of logits\n            and returns the updated logits.\n    \"\"\"\n\n@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)\ndef apply_top_k(logprobs: mx.array, top_k: int) -> mx.array:\n    \"\"\"\n    Sample from only the top K tokens ranked by probability.\n\n    Args:\n        logprobs: A vector of log probabilities.\n        top_k (int): Top k tokens to sample from.\n    \"\"\"\n\n@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)\ndef apply_min_p(\n    logprobs: mx.array, min_p: float, min_tokens_to_keep: int = ...\n) -> mx.array:\n    \"\"\"\n    Apply min-p sampling to the logprobs.\n\n    Min-p keeps all tokens that are above a minimum probability, scaled by the\n    probability of the most likely token. As a result, the filter is more\n    aggressive given a very high-probability token.\n\n    Args:\n        logprobs: A vector of log probabilities.\n        min_p (float): Minimum token probability. 
Typical values are in the\n            0.01-0.2 range, comparably selective as setting `top_p` in the\n            0.99-0.8 range.\n        min_tokens_to_keep (int, optional): Minimum number of tokens that cannot\n            be filtered. Default: ``1``.\n\n    \"\"\"\n\n@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)\ndef apply_top_p(logprobs: mx.array, top_p: float) -> mx.array:\n    \"\"\"\n    Apply top-p (nucleus) sampling to logits.\n\n    Args:\n        logprobs: A vector of log probabilities.\n        top_p: The cumulative probability threshold for top-p filtering.\n    Returns:\n        token selected based on the top-p criterion.\n    \"\"\"\n\n@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)\ndef apply_xtc(\n    logits: mx.array,\n    xtc_probability: float,\n    xtc_threshold: float,\n    xtc_special_tokens: List[int],\n) -> mx.array:\n    \"\"\"\n    Apply XTC sampling to the logits.\n\n    Args:\n        logits: The logits from the model's output.\n        xtc_probability (float): Probability of XTC sampling to happen for each token\n        xtc_threshold (float): The threshold the probs need to reach for being sampled.\n        special_tokens_ids (list(int)): List of special tokens IDs to be excluded from XTC sampling.\n    \"\"\"\n\n@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)\ndef categorical_sampling(logits, temp):  # -> array:\n    ...\ndef make_repetition_penalty(\n    penalty: float, context_size: int = ...\n):  # -> Callable[..., Any]:\n    \"\"\"\n    Make repetition penalty processor.\n\n    Paper: https://arxiv.org/abs/1909.05858\n\n    Args:\n        penalty (float): The repetition penalty factor to be applied.\n        context_size (int): The number of previous tokens to use.\n            Default: ``20``.\n\n    Returns:\n        Callable[[mx.array, List[int]], mx.array]:\n            The repetition penalty processor.\n    \"\"\"\n"
  },
  {
    "path": ".mlx_typings/mlx_lm/tokenizer_utils.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nfrom functools import partial\nfrom pathlib import Path\nfrom typing import Any\n\nfrom transformers import PreTrainedTokenizerFast\n\nclass StreamingDetokenizer:\n    \"\"\"The streaming detokenizer interface so that we can detokenize one token at a time.\n\n    Example usage is as follows:\n\n        detokenizer = ...\n\n        # Reset the tokenizer state\n        detokenizer.reset()\n\n        for token in generate(...):\n            detokenizer.add_token(token.item())\n\n            # Contains the whole text so far. Some tokens may not be included\n            # since it contains whole words usually.\n            detokenizer.text\n\n            # Contains the printable segment (usually a word) since the last\n            # time it was accessed\n            detokenizer.last_segment\n\n            # Contains all the tokens added so far\n            detokenizer.tokens\n\n        # Make sure that we detokenize any remaining tokens\n        detokenizer.finalize()\n\n        # Now detokenizer.text should match tokenizer.decode(detokenizer.tokens)\n    \"\"\"\n\n    __slots__ = ...\n    def reset(self) -> None: ...\n    def add_token(self, token: int) -> None: ...\n    def finalize(self) -> None: ...\n    @property\n    def last_segment(self) -> str:\n        \"\"\"Return the last segment of readable text since last time this property was accessed.\"\"\"\n\nclass NaiveStreamingDetokenizer(StreamingDetokenizer):\n    \"\"\"NaiveStreamingDetokenizer relies on the underlying tokenizer\n    implementation and should work with every tokenizer.\n\n    Its complexity is O(T^2) where T is the longest line since it will\n    repeatedly detokenize the same tokens until a new line is generated.\n    \"\"\"\n    def __init__(self, tokenizer) -> None: ...\n    def reset(self):  # -> None:\n        ...\n    def add_token(self, token):  # -> None:\n        ...\n    def finalize(self):  # -> None:\n        ...\n    @property\n    def text(self):  # -> str:\n        ...\n\nclass SPMStreamingDetokenizer(StreamingDetokenizer):\n    \"\"\"A streaming detokenizer for SPM models.\n\n    It adds tokens to the text if the next token starts with the special SPM\n    underscore which results in linear complexity.\n    \"\"\"\n    def __init__(self, tokenizer, trim_space=...) 
-> None: ...\n    def reset(self):  # -> None:\n        ...\n    def add_token(self, token):  # -> None:\n        ...\n    def finalize(self):  # -> None:\n        ...\n\nclass BPEStreamingDetokenizer(StreamingDetokenizer):\n    \"\"\"A streaming detokenizer for OpenAI style BPE models.\n\n    It adds tokens to the text if the next token starts with a space similar to\n    the SPM detokenizer.\n    \"\"\"\n\n    _byte_decoder = ...\n    _space_matches = ...\n    def __init__(self, tokenizer) -> None: ...\n    def reset(self):  # -> None:\n        ...\n    def add_token(self, token):  # -> None:\n        ...\n    def finalize(self):  # -> None:\n        ...\n    @classmethod\n    def make_byte_decoder(cls):  # -> None:\n        \"\"\"See https://github.com/openai/gpt-2/blob/master/src/encoder.py for the rationale.\"\"\"\n\nclass TokenizerWrapper:\n    \"\"\"A wrapper that combines an HF tokenizer and a detokenizer.\n\n    Accessing any attribute other than the ``detokenizer`` is forwarded to the\n    huggingface tokenizer.\n    \"\"\"\n\n    _tokenizer: PreTrainedTokenizerFast\n    eos_token_id: int | None\n    eos_token: str | None\n    eos_token_ids: list[int] | set[int] | None\n    bos_token_id: int | None\n    bos_token: str | None\n    vocab_size: int\n    all_special_tokens: list[str]\n    think_start_id: int | None\n    think_end_id: int | None\n\n    def __init__(\n        self,\n        tokenizer: Any,\n        detokenizer_class: Any = ...,\n        eos_token_ids: list[int] | set[int] | None = ...,\n        chat_template: Any = ...,\n        tool_parser: Any = ...,\n        tool_call_start: str | None = ...,\n        tool_call_end: str | None = ...,\n    ) -> None: ...\n    def encode(self, text: str, **kwargs: Any) -> list[int]: ...\n    def decode(self, token_ids: list[int], **kwargs: Any) -> str: ...\n    def apply_chat_template(\n        self,\n        messages: list[dict[str, Any]],\n        tokenize: bool = False,\n        add_generation_prompt: bool = False,\n        tools: Any = None,\n        **kwargs: Any,\n    ) -> str: ...\n    def get_vocab(self) -> dict[str, int]: ...\n    def add_eos_token(self, token: str) -> None: ...\n    @property\n    def has_thinking(self) -> bool: ...\n    @property\n    def think_start(self) -> str | None: ...\n    @property\n    def think_end(self) -> str | None: ...\n    @property\n    def has_tool_calling(self) -> bool: ...\n    @property\n    def tool_call_start(self) -> str | None: ...\n    @property\n    def tool_call_end(self) -> str | None: ...\n    @property\n    def detokenizer(self) -> NaiveStreamingDetokenizer:\n        \"\"\"Get a stateful streaming detokenizer.\"\"\"\n\n    def __getattr__(self, attr: str) -> Any: ...\n    def __setattr__(self, attr: str, value: Any) -> None: ...\n\nclass NewlineTokenizer(PreTrainedTokenizerFast):\n    \"\"\"A tokenizer that encodes newlines as <n> and decodes <n> back to newlines.\"\"\"\n    def __init__(self, *args, **kwargs) -> None: ...\n    def encode(self, text, **kwargs):  # -> list[int]:\n        ...\n    def encode_batch(self, texts, **kwargs): ...\n    def decode(self, *args, **kwargs):  # -> str:\n        ...\n    def batch_decode(self, *args, **kwargs):  # -> list[str]:\n        ...\n\ndef load(\n    model_path: Path,\n    tokenizer_config_extra: dict[str, Any] | None = None,\n    eos_token_ids: list[int] | int | None = None,\n) -> TokenizerWrapper:\n    \"\"\"Load a huggingface tokenizer and try to infer the type of streaming\n    
detokenizer to use.\n\n    Note, to use a fast streaming tokenizer, pass a local file path rather than\n    a Hugging Face repo ID.\n    \"\"\"\n\n# Alias for backward compatibility\nload_tokenizer = load\n\ndef no_bos_or_eos(sequence: list[int], bos: int, eos: int) -> list[int]: ...\n"
  },
  {
    "path": ".mlx_typings/mlx_lm/utils.pyi",
    "content": "\"\"\"\nThis type stub file was generated by pyright.\n\"\"\"\n\nimport os\nfrom pathlib import Path\nfrom typing import Any, Callable, Dict, Optional, Tuple, Type, Union\n\nimport mlx.nn as nn\nfrom transformers.utils.auto_docstring import ModelArgs\n\nfrom .tokenizer_utils import TokenizerWrapper\n\nif os.getenv(\"MLXLM_USE_MODELSCOPE\", \"False\").lower() == \"true\": ...\nelse: ...\nMODEL_REMAPPING = ...\nMAX_FILE_SIZE_GB = ...\n\ndef compute_bits_per_weight(model): ...\ndef hf_repo_to_path(hf_repo):  # -> Path:\n    ...\ndef load_config(model_path: Path) -> dict: ...\ndef load_model(\n    model_path: Path,\n    lazy: bool = False,\n    strict: bool = True,\n    model_config: dict[str, Any] = {},\n    get_model_classes: Callable[\n        [dict[str, Any]], Tuple[Type[nn.Module], Type[ModelArgs]]\n    ] = ...,\n) -> Tuple[nn.Module, dict[str, Any]]:\n    \"\"\"\n    Load and initialize the model from a given path.\n\n    Args:\n        model_path (Path): The path to load the model from.\n        lazy (bool): If False eval the model parameters to make sure they are\n            loaded in memory before returning, otherwise they will be loaded\n            when needed. Default: ``False``\n        strict (bool): Whether or not to raise an exception if weights don't\n            match. Default: ``True``\n        model_config (dict, optional): Optional configuration parameters for the\n            model. Defaults to an empty dictionary.\n        get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):\n            A function that returns the model class and model args class given a config.\n            Defaults to the ``_get_classes`` function.\n\n    Returns:\n        Tuple[nn.Module, dict[str, Any]]: The loaded and initialized model and config.\n\n    Raises:\n        FileNotFoundError: If the weight files (.safetensors) are not found.\n        ValueError: If the model class or args class are not found or cannot be instantiated.\n    \"\"\"\n\ndef load(\n    path_or_hf_repo: str,\n    tokenizer_config=...,\n    model_config=...,\n    adapter_path: Optional[str] = ...,\n    lazy: bool = ...,\n    return_config: bool = ...,\n    revision: str = ...,\n) -> Union[\n    Tuple[nn.Module, TokenizerWrapper],\n    Tuple[nn.Module, TokenizerWrapper, Dict[str, Any]],\n]:\n    \"\"\"\n    Load the model and tokenizer from a given path or a huggingface repository.\n\n    Args:\n        path_or_hf_repo (Path): The path or the huggingface repository to load the model from.\n        tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.\n            Defaults to an empty dictionary.\n        model_config(dict, optional): Configuration parameters specifically for the model.\n            Defaults to an empty dictionary.\n        adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers\n            to the model. Default: ``None``.\n        lazy (bool): If ``False`` eval the model parameters to make sure they are\n            loaded in memory before returning, otherwise they will be loaded\n            when needed. 
Default: ``False``\n        return_config (bool, optional): If ``True``, return the model config as the last item.\n        revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.\n    Returns:\n        Union[Tuple[nn.Module, TokenizerWrapper], Tuple[nn.Module, TokenizerWrapper, Dict[str, Any]]]:\n            A tuple containing the loaded model, tokenizer and, if requested, the model config.\n\n    Raises:\n        FileNotFoundError: If config file or safetensors are not found.\n        ValueError: If model class or args class are not found.\n    \"\"\"\n\ndef make_shards(weights: dict, max_file_size_gb: int = ...) -> list:\n    \"\"\"\n    Splits the weights into smaller shards.\n\n    Args:\n        weights (dict): Model weights.\n        max_file_size_gb (int): Maximum size of each shard in gigabytes.\n\n    Returns:\n        list: List of weight shards.\n    \"\"\"\n\ndef create_model_card(\n    path: Union[str, Path], hf_path: Union[str, Path, None]\n):  # -> None:\n    \"\"\"\n    Creates a model card for the model at the given local path.\n\n    Args:\n        path (Union[str, Path]): Local path to the model.\n        hf_path (Union[str, Path, None]): Path to the original Hugging Face model.\n    \"\"\"\n\ndef upload_to_hub(path: str, upload_repo: str):  # -> None:\n    \"\"\"\n    Uploads the model to Hugging Face hub.\n\n    Args:\n        path (str): Local path to the model.\n        upload_repo (str): Name of the HF repo to upload to.\n    \"\"\"\n\ndef save_model(\n    save_path: Union[str, Path], model: nn.Module, *, donate_model: bool = ...\n) -> None:\n    \"\"\"Save model weights and metadata index into specified directory.\"\"\"\n\ndef quantize_model(\n    model: nn.Module,\n    config: dict,\n    group_size: int,\n    bits: int,\n    mode: str = ...,\n    quant_predicate: Optional[Callable[[str, nn.Module], Union[bool, dict]]] = ...,\n) -> Tuple[nn.Module, dict]:\n    \"\"\"\n    Applies quantization to the model weights.\n\n    Args:\n        model (nn.Module): The model to be quantized.\n        config (dict): Model configuration.\n        group_size (int): Group size for quantization.\n        bits (int): Bits per weight for quantization.\n        mode (str): The quantization mode.\n        quant_predicate (Callable): A callable that decides how to quantize\n          each layer based on the path. Accepts the layer `path` and the\n          `module`. Returns either a bool to signify quantize/no quantize or\n          a dict of quantization parameters to pass to `to_quantized`.\n\n    Returns:\n        Tuple: Tuple containing quantized model and config.\n    \"\"\"\n\ndef save_config(config: dict, config_path: Union[str, Path]) -> None:\n    \"\"\"Save the model configuration to the ``config_path``.\n\n    The final configuration will be sorted before saving for better readability.\n\n    Args:\n        config (dict): The model configuration.\n        config_path (Union[str, Path]): Model configuration file path.\n    \"\"\"\n\ndef save(\n    dst_path: Union[str, Path],\n    src_path_or_repo: Union[str, Path],\n    model: nn.Module,\n    tokenizer: TokenizerWrapper,\n    config: Dict[str, Any],\n    donate_model: bool = ...,\n):  # -> None:\n    ...\ndef common_prefix_len(list1, list2):  # -> int:\n    \"\"\"\n    Calculates the length of the common prefix of two lists.\n\n    Args:\n        list1: The first list of strings.\n        list2: The second list of strings.\n\n    Returns:\n        The length of the common prefix. Returns 0 if lists are empty\n
        or do not match at the first element.\n    \"\"\"\n\ndef does_model_support_input_embeddings(model: nn.Module) -> bool:\n    \"\"\"\n    Check if the model supports input_embeddings in its call signature.\n\n    Args:\n        model (nn.Module): The model to check.\n\n    Returns:\n        bool: True if the model supports input_embeddings, False otherwise.\n    \"\"\"\n"
  },
  {
    "path": ".python-version",
    "content": "3.13\n"
  },
  {
    "path": ".swift-format",
    "content": "{\n  \"version\": 1,\n  \"indentation\": {\n    \"spaces\": 4\n  }\n}\n"
  },
  {
    "path": ".vscode/extensions.json",
    "content": "{\n    \"recommendations\": [\n        \"detachhead.basedpyright\",\n        \"ms-python.python\"\n    ],\n    \"unwantedRecommendations\": [\n        \"ms-python.vscode-pylance\",\n        \"ms-python.pyright\",\n        \"ms-python.mypy-type-checker\"\n    ]\n}"
  },
  {
    "path": ".vscode/settings.json",
    "content": "{\n    \"basedpyright.importStrategy\": \"fromEnvironment\"\n}"
  },
  {
    "path": ".zed/settings.json",
    "content": "// Folder-specific settings\n//\n// For a full list of overridable settings, and general information on folder-specific settings,\n// see the documentation: https://zed.dev/docs/configuring-zed#settings-files\n{\n  \"lsp\": {\n    \"nix_python\": {\n      \"binary\": {\n        \"path\": \"nix\",\n        \"arguments\": [\n          \"run\",\n          \"--quiet\",\n          \"--no-warn-dirty\",\n          \"--no-allow-import-from-derivation\",\n          \"--print-build-logs\",\n          \"never\",\n          \"${projectRoot}#python-lsp\",\n          \"--\",\n          \"--stdio\"\n        ]\n      }\n    }\n  },\n  \"languages\": {\n    \"Python\": {\n      \"language_servers\": [\"nix_python\"]\n    }\n  }\n}\n"
  },
  {
    "path": "AGENTS.md",
    "content": "# AGENTS.md\n\nThis file provides guidance to AI coding agents when working with code in this repository.\n\n## Project Overview\n\nexo is a distributed AI inference system that connects multiple devices into a cluster. It enables running large language models across multiple machines using MLX as the inference backend and libp2p for peer-to-peer networking.\n\n## Build & Run Commands\n\n```bash\n# Build the dashboard (required before running exo)\ncd dashboard && npm install && npm run build && cd ..\n\n# Run exo (starts both master and worker with API at http://localhost:52415)\nuv run exo\n\n# Run with verbose logging\nuv run exo -v   # or -vv for more verbose\n\n# Run tests (excludes slow tests by default)\nuv run pytest\n\n# Run all tests including slow tests\nuv run pytest -m \"\"\n\n# Run a specific test file\nuv run pytest src/exo/shared/tests/test_election.py\n\n# Run a specific test function\nuv run pytest src/exo/shared/tests/test_election.py::test_function_name\n\n# Type checking (strict mode)\nuv run basedpyright\n\n# Linting\nuv run ruff check\n\n# Format code (using nix)\nnix fmt\n```\n\n## Pre-Commit Checks (REQUIRED)\n\n**IMPORTANT: Always run these checks before committing code. CI will fail if these don't pass.**\n\n```bash\n# 1. Type checking - MUST pass with 0 errors\nuv run basedpyright\n\n# 2. Linting - MUST pass\nuv run ruff check\n\n# 3. Formatting - MUST be applied\nnix fmt\n\n# 4. Tests - MUST pass\nuv run pytest\n```\n\nRun all checks in sequence:\n```bash\nuv run basedpyright && uv run ruff check && nix fmt && uv run pytest\n```\n\nIf `nix fmt` changes any files, stage them before committing. The CI runs `nix flake check` which verifies formatting, linting, and runs Rust tests.\n\n## Architecture\n\n### Node Composition\nA single exo `Node` (src/exo/main.py) runs multiple components:\n- **Router**: libp2p-based pub/sub messaging via Rust bindings (exo_pyo3_bindings)\n- **Worker**: Handles inference tasks, downloads models, manages runner processes\n- **Master**: Coordinates cluster state, places model instances across nodes\n- **Election**: Bully algorithm for master election\n- **API**: FastAPI server for OpenAI-compatible chat completions\n\n### Message Flow\nComponents communicate via typed pub/sub topics (src/exo/routing/topics.py):\n- `GLOBAL_EVENTS`: Master broadcasts indexed events to all workers\n- `LOCAL_EVENTS`: Workers send events to master for indexing\n- `COMMANDS`: Workers/API send commands to master\n- `ELECTION_MESSAGES`: Election protocol messages\n- `CONNECTION_MESSAGES`: libp2p connection updates\n\n### Event Sourcing\nThe system uses event sourcing for state management:\n- `State` (src/exo/shared/types/state.py): Immutable state object\n- `apply()` (src/exo/shared/apply.py): Pure function that applies events to state\n- Master indexes events and broadcasts; workers apply indexed events\n\n### Key Type Hierarchy\n- `src/exo/shared/types/`: Pydantic models for all shared types\n  - `events.py`: Event types (discriminated union)\n  - `commands.py`: Command types\n  - `tasks.py`: Task types for worker execution\n  - `state.py`: Cluster state model\n\n### Rust Components\nRust code in `rust/` provides:\n- `networking`: libp2p networking (gossipsub, peer discovery)\n- `exo_pyo3_bindings`: PyO3 bindings exposing Rust to Python\n- `system_custodian`: System-level operations\n\n### Dashboard\nSvelte 5 + TypeScript frontend in `dashboard/`. 
Build output goes to `dashboard/build/` and is served by the API.\n\n## Code Style Requirements\n\nFrom .cursorrules:\n- Strict, exhaustive typing - never bypass the type-checker\n- Use `Literal[...]` for enum-like sets, `typing.NewType` for primitives\n- Pydantic models with `frozen=True` and `strict=True`\n- Pure functions with injectable effect handlers for side-effects\n- Descriptive names - no abbreviations or 3-letter acronyms\n- Catch exceptions only where you can handle them meaningfully\n- Use `@final` and immutability wherever applicable\n\n## Testing\n\nTests use pytest-asyncio with `asyncio_mode = \"auto\"`. Tests are in `tests/` subdirectories alongside the code they test. The `EXO_TESTS=1` env var is set during tests.\n\n## Dashboard UI Testing & Screenshots\n\n### Building and Running the Dashboard\n```bash\n# Build the dashboard (must be done before running exo)\ncd dashboard && npm install && npm run build && cd ..\n\n# Start exo (serves the dashboard at http://localhost:52415)\nuv run exo &\nsleep 8  # Wait for server to start\n```\n\n### Taking Headless Screenshots with Playwright\nUse Playwright with headless Chromium for programmatic screenshots — no manual browser interaction needed.\n\n**Setup (one-time):**\n```bash\nnpx --yes playwright install chromium\ncd /tmp && npm init -y && npm install playwright\n```\n\n**Taking screenshots:**\n```javascript\n// Run from /tmp where playwright is installed: cd /tmp && node -e \"...\"\nconst { chromium } = require('playwright');\n(async () => {\n  const browser = await chromium.launch({ headless: true });\n  const page = await browser.newPage({ viewport: { width: 1280, height: 800 } });\n  await page.goto('http://localhost:52415', { waitUntil: 'networkidle' });\n  await page.waitForTimeout(2000);\n\n  // Inject test data into localStorage if needed (e.g., recent models)\n  await page.evaluate(() => {\n    localStorage.setItem('exo-recent-models', JSON.stringify([\n      { modelId: 'mlx-community/Qwen3-30B-A3B-4bit', launchedAt: Date.now() },\n    ]));\n  });\n  await page.reload({ waitUntil: 'networkidle' });\n  await page.waitForTimeout(2000);\n\n  // Interact with UI elements\n  await page.locator('text=SELECT MODEL').click();\n  await page.waitForTimeout(1000);\n\n  // Take screenshot\n  await page.screenshot({ path: '/tmp/screenshot.png', fullPage: false });\n  await browser.close();\n})();\n```\n\n### Uploading Images to GitHub PRs\nGitHub's API doesn't support direct image upload for PR comments. Workaround:\n\n1. **Commit images to the branch** (temporarily):\n   ```bash\n   cp /tmp/screenshot.png .\n   git add screenshot.png\n   git commit -m \"temp: add screenshots for PR\"\n   git push origin <branch>\n   COMMIT_SHA=$(git rev-parse HEAD)\n   ```\n\n2. **Post PR comment** referencing the raw image URL (uses permanent commit SHA so images survive deletion):\n   ```bash\n   gh pr comment <PR_NUMBER> --body \"![Screenshot](https://raw.githubusercontent.com/exo-explore/exo/${COMMIT_SHA}/screenshot.png)\"\n   ```\n\n3. **Remove the images** from the branch:\n   ```bash\n   git rm screenshot.png\n   git commit -m \"chore: remove temporary screenshot files\"\n   git push origin <branch>\n   ```\n   The images still render in the PR comment because they reference the permanent commit SHA.\n"
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "# Contributing to EXO\n\nThank you for your interest in contributing to EXO!\n\n## Getting Started\n\nTo run EXO from source:\n\n**Prerequisites:**\n- [uv](https://github.com/astral-sh/uv) (for Python dependency management)\n  ```bash\n  brew install uv\n  ```\n- [macmon](https://github.com/vladkens/macmon) (for hardware monitoring on Apple Silicon)\n  ```bash\n  brew install macmon\n  ```\n\n```bash\ngit clone https://github.com/exo-explore/exo.git\ncd exo/dashboard\nnpm install && npm run build && cd ..\nuv run exo\n```\n\n## Development\n\nEXO is built with a mix of Rust, Python, and TypeScript (Svelte for the dashboard), and the codebase is actively evolving. Before starting work:\n\n- Pull the latest source to ensure you're working with the most recent code\n- Keep your changes focused - implement one feature or fix per pull request\n- Avoid combining unrelated changes, even if they seem small\n\nThis makes reviews faster and helps us maintain code quality as the project evolves.\n\n## Code Style\n\nWrite pure functions where possible. When adding new code, prefer Rust unless there's a good reason otherwise. Leverage the type systems available to you - Rust's type system, Python type hints, and TypeScript types. Comments should explain why you're doing something, not what the code does - especially for non-obvious decisions.\n\nRun `nix fmt` to auto-format your code before submitting.\n\n## Model Cards\n\nEXO uses TOML-based model cards to define model metadata and capabilities. Model cards are stored in:\n- `resources/inference_model_cards/` for text generation models\n- `resources/image_model_cards/` for image generation models\n- `~/.exo/custom_model_cards/` for user-added custom models\n\n### Adding a Model Card\n\nTo add a new model, create a TOML file with the following structure:\n\n```toml\nmodel_id = \"mlx-community/Llama-3.2-1B-Instruct-4bit\"\nn_layers = 16\nhidden_size = 2048\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"llama\"\nquantization = \"4bit\"\nbase_model = \"Llama 3.2 1B\"\ncapabilities = [\"text\"]\n\n[storage_size]\nin_bytes = 729808896\n```\n\n### Required Fields\n\n- `model_id`: Hugging Face model identifier\n- `n_layers`: Number of transformer layers\n- `hidden_size`: Hidden dimension size\n- `supports_tensor`: Whether the model supports tensor parallelism\n- `tasks`: List of supported tasks (`TextGeneration`, `TextToImage`, `ImageToImage`)\n- `family`: Model family (e.g., \"llama\", \"deepseek\", \"qwen\")\n- `quantization`: Quantization level (e.g., \"4bit\", \"8bit\", \"bf16\")\n- `base_model`: Human-readable base model name\n- `capabilities`: List of capabilities (e.g., `[\"text\"]`, `[\"text\", \"thinking\"]`)\n\n### Optional Fields\n\n- `components`: For multi-component models (like image models with separate text encoders and transformers)\n- `uses_cfg`: Whether the model uses classifier-free guidance (for image models)\n- `trust_remote_code`: Whether to allow remote code execution (defaults to `false` for security)\n\n### Capabilities\n\nThe `capabilities` field defines what the model can do:\n- `text`: Standard text generation\n- `thinking`: Model supports chain-of-thought reasoning\n- `thinking_toggle`: Thinking can be enabled/disabled via `enable_thinking` parameter\n- `image_edit`: Model supports image-to-image editing (FLUX.1-Kontext)\n\n### Security Note\n\nBy default, `trust_remote_code` is set to `false` for security. 
Only enable it if the model explicitly requires remote code execution from the Hugging Face hub.\n\n## API Adapters\n\nEXO supports multiple API formats through an adapter pattern. Adapters convert API-specific request formats to the internal `TextGenerationTaskParams` format and convert internal token chunks back to API-specific responses.\n\n### Adapter Architecture\n\nAll adapters live in `src/exo/master/adapters/` and follow the same pattern:\n\n1. Convert API-specific requests to `TextGenerationTaskParams`\n2. Handle both streaming and non-streaming response generation\n3. Convert internal `TokenChunk` objects to API-specific formats\n4. Manage error handling and edge cases\n\n### Existing Adapters\n\n- `chat_completions.py`: OpenAI Chat Completions API\n- `claude.py`: Anthropic Claude Messages API\n- `responses.py`: OpenAI Responses API\n- `ollama.py`: Ollama API (for OpenWebUI compatibility)\n\n### Adding a New API Adapter\n\nTo add support for a new API format:\n\n1. Create a new adapter file in `src/exo/master/adapters/`\n2. Implement a request conversion function:\n   ```python\n   def your_api_request_to_text_generation(\n       request: YourAPIRequest,\n   ) -> TextGenerationTaskParams:\n       # Convert API request to internal format\n       pass\n   ```\n3. Implement streaming response generation:\n   ```python\n   async def generate_your_api_stream(\n       command_id: CommandId,\n       chunk_stream: AsyncGenerator[TokenChunk | ErrorChunk | ToolCallChunk, None],\n   ) -> AsyncGenerator[str, None]:\n       # Convert internal chunks to API-specific streaming format\n       pass\n   ```\n4. Implement non-streaming response collection:\n   ```python\n   async def collect_your_api_response(\n       command_id: CommandId,\n       chunk_stream: AsyncGenerator[TokenChunk | ErrorChunk | ToolCallChunk, None],\n   ) -> AsyncGenerator[str]:\n       # Collect all chunks and return single response\n       pass\n   ```\n5. Register the adapter endpoints in `src/exo/master/api.py`\n\nThe adapter pattern keeps API-specific logic isolated from core inference systems. Internal systems (worker, runner, event sourcing) only see `TextGenerationTaskParams` and `TokenChunk` objects - no API-specific types cross the adapter boundary.\n\nFor detailed API documentation, see [docs/api.md](docs/api.md).\n\n## Testing\n\nEXO relies heavily on manual testing at this point in the project, but this is evolving. Before submitting a change, test both before and after to demonstrate how your change improves behavior. Do the best you can with the hardware you have available - if you need help testing, ask and we'll do our best to assist. Add automated tests where possible - we're actively working to substantially improve our automated testing story.\n\n## Submitting Changes\n\n1. Fork the repository\n2. Create a feature branch (`git checkout -b feature/your-feature`)\n3. Commit your changes (`git commit -am 'Add some feature'`)\n4. Push to the branch (`git push origin feature/your-feature`)\n5. Open a Pull Request and follow the PR template\n\n## Reporting Issues\n\nIf you find a bug or have a feature request, please open an issue on GitHub with:\n- A clear description of the problem or feature\n- Steps to reproduce (for bugs)\n- Expected vs actual behavior\n- Your environment (macOS version, hardware, etc.)\n\n## Questions?\n\nJoin our community:\n- [X](https://x.com/exolabs)\n"
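## Appendix: Adapter Stream Sketch

As a simplified, self-contained illustration of step 3 from "Adding a New API Adapter" above: the stand-in chunk classes and the newline-delimited JSON output format below are illustrative only and do not match exo's real `TokenChunk`/`ErrorChunk` types or any existing adapter.

```python
import json
from collections.abc import AsyncGenerator
from dataclasses import dataclass


# Stand-in chunk types so the sketch runs on its own; exo's real chunk
# types live in the core and have different shapes.
@dataclass
class TokenChunk:
    text: str


@dataclass
class ErrorChunk:
    message: str


async def generate_lines_api_stream(
    command_id: str,  # exo uses a CommandId type; a plain str keeps the sketch small
    chunk_stream: AsyncGenerator[TokenChunk | ErrorChunk, None],
) -> AsyncGenerator[str, None]:
    """Convert internal chunks into a fictional newline-delimited JSON stream."""
    async for chunk in chunk_stream:
        if isinstance(chunk, TokenChunk):
            yield json.dumps({"id": command_id, "delta": chunk.text}) + "\n"
        else:
            # Surface the error to the client, then stop streaming.
            yield json.dumps({"id": command_id, "error": chunk.message}) + "\n"
            return
```

The real adapters follow the same shape: consume internal chunks, emit strings in the target API's wire format, and never let API-specific types cross back into the core.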
  },
  {
    "path": "Cargo.toml",
    "content": "[workspace]\nresolver = \"3\"\nmembers = [\"rust/networking\", \"rust/exo_pyo3_bindings\", \"rust/util\"]\n\n[workspace.package]\nversion = \"0.0.1\"\nedition = \"2024\"\n\n[profile.dev]\nopt-level = 1\ndebug = true\n\n[profile.release]\nopt-level = 3\n\n# Common shared dependendencies configured once at the workspace\n# level, to be re-used more easily across workspace member crates.\n#\n# Common configurations include versions, paths, features, etc.\n[workspace.dependencies]\n## Crate members as common dependencies\nnetworking = { path = \"rust/networking\" }\nutil = { path = \"rust/util\" }\n\n# Macro dependecies\nextend = \"1.2\"\ndelegate = \"0.13\"\n\n# Utility dependencies\nkeccak-const = \"0.2\"\n\n# Async dependencies\nasync-stream = \"0.3\"\ntokio = \"1.46\"\nfutures-lite = \"2.6.1\"\nfutures-timer = \"3.0\"\n\n# Data structures\neither = \"1.15\"\n\n# Tracing/logging\nlog = \"0.4\"\n\n# networking\nlibp2p = \"0.56\"\nlibp2p-tcp = \"0.44\"\n\n[workspace.lints.rust]\nstatic_mut_refs = \"warn\"      # Or use \"warn\" instead of deny\nincomplete_features = \"allow\"\n\n# Clippy's lint category level configurations;\n# every member crate needs to inherit these by adding\n#\n#     ```toml\n#     [lints]\n#     workspace = true\n#     ```\n#\n# to their `Cargo.toml` files\n[workspace.lints.clippy]\n# Clippy lint categories meant to be enabled all at once\ncorrectness = { level = \"deny\", priority = -1 }\nsuspicious = { level = \"warn\", priority = -1 }\nstyle = { level = \"warn\", priority = -1 }\ncomplexity = { level = \"warn\", priority = -1 }\nperf = { level = \"warn\", priority = -1 }\npedantic = { level = \"warn\", priority = -1 }\nnursery = { level = \"warn\", priority = -1 }\ncargo = { level = \"warn\", priority = -1 }\n\n# Individual Clippy lints from the `restriction` category\narithmetic_side_effects = \"warn\"\nas_conversions = \"warn\"\nassertions_on_result_states = \"warn\"\nclone_on_ref_ptr = \"warn\"\ndecimal_literal_representation = \"warn\"\ndefault_union_representation = \"warn\"\nderef_by_slicing = \"warn\"\ndisallowed_script_idents = \"deny\"\nelse_if_without_else = \"warn\"\nempty_enum_variants_with_brackets = \"warn\"\nempty_structs_with_brackets = \"warn\"\nerror_impl_error = \"warn\"\nexit = \"deny\"\nexpect_used = \"warn\"\nfloat_cmp_const = \"warn\"\nget_unwrap = \"warn\"\nif_then_some_else_none = \"warn\"\nimpl_trait_in_params = \"warn\"\nindexing_slicing = \"warn\"\ninfinite_loop = \"warn\"\nlet_underscore_must_use = \"warn\"\nlet_underscore_untyped = \"warn\"\nlossy_float_literal = \"warn\"\nmem_forget = \"warn\"\nmissing_inline_in_public_items = \"warn\"\nmultiple_inherent_impl = \"warn\"\nmultiple_unsafe_ops_per_block = \"warn\"\nmutex_atomic = \"warn\"\nnon_zero_suggestions = \"warn\"\npanic = \"warn\"\npartial_pub_fields = \"warn\"\npattern_type_mismatch = \"warn\"\npub_without_shorthand = \"warn\"\nrc_buffer = \"warn\"\nrc_mutex = \"warn\"\nredundant_type_annotations = \"warn\"\nrenamed_function_params = \"warn\"\nrest_pat_in_fully_bound_structs = \"warn\"\nsame_name_method = \"warn\"\nself_named_module_files = \"deny\"\nsemicolon_inside_block = \"warn\"\nshadow_same = \"warn\"\nshadow_unrelated = \"warn\"\nstr_to_string = \"warn\"\nstring_add = \"warn\"\nstring_lit_chars_any = \"warn\"\nstring_to_string = \"warn\"\ntests_outside_test_module = \"warn\"\ntodo = \"warn\"\ntry_err = \"warn\"\nundocumented_unsafe_blocks = \"warn\"\nunnecessary_safety_comment = \"warn\"\nunnecessary_safety_doc = \"warn\"\nunneeded_field_pattern = 
\"warn\"\nunseparated_literal_suffix = \"warn\"\nunused_result_ok = \"warn\"\nunused_trait_names = \"warn\"\nunwrap_used = \"warn\"\nverbose_file_reads = \"warn\"\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. (Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright 2025 Exo Technologies Ltd\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n\n"
  },
  {
    "path": "MISSED_THINGS.md",
    "content": "# Missed things\n[X] Log EXO_LIBP2P_NAMESPACE on start in exo/main.py\n[X] Ordering of warmup was changed, which is wrong. It was changed to rank < n-1, then rank=n-1. It should be rank!=0 then rank=0 (this matches the auto_parallel implementation. NOTE: we use a different convention to mlx-lm, our terminal rank is rank=n-1 whereas mlx-lm is rank=0 hence i can see why this was changed wrongly).\n[X] Downloads keying by model_id not shard_metadata (worker/plan.py, worker/main.py).\n[X] Fetching download status of all models on start\n[X] Deduplication of tasks in plan_step.\n[X] resolve_allow_patterns should just be wildcard now.\n[X] no mx_barrier in genreate.py mlx_generate at the end.\n[] cache assertion not needed in auto_parallel.py PipelineLastLayer.\n[X] GPTOSS support dropped in auto_parallel.py.\n[X] sharding changed \"all-to-sharded\" became _all_to_sharded in auto_parallel.py.\n[X] same as above with \"sharded-to-all\" became _sharded_to_all in auto_parallel.py.\n[X] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.\n[] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.\n[X] KV_CACHE_BITS should be None to disable quantized KV cache.\n[X] Dropped _set_nofile_limit in utils_mlx.py.\n[X] We have group optional in load_mlx_items in utils_mlx.py.\n[X] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.\n[X] Dropped model.make_cache in make_kv_cache in utils_mlx.py.\n[X] We put cache limit back in utils_mlx.py.\n[X] topology.py remove_node removes the connections after checking if node is is in self._node_id_to_rx_id_map. on beta_1 it checks after, so would remove stale connections I guess?\n[X] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is transforemrs version doesn't support the tokenizer for Glm 4.7. rc-1 does but we can't upgrade as it breaks other things.)\n[] try-except in _command_processor only excepts ValueError. This was silently failing leading to un-debuggable errors (we had a KeyError that was happening ). Changed this to catch Exception instead of ValueError. See exo-v2 89ae38405e0052e3c22405daf094b065878aa873 and fb99fea69b5a39017efc90c5dad0072e677455f0.\n[X] In placement.py, place_instance no longer looks at model_meta.supports_tensor and check if this tensor parallel number of nodes is supported by the model's tensor dimensions.\n[X] In placement.py, place_instanec, we no longer have the special case to exclude DeepSeek v3.1 pipeline parallel (it doesn't work).\n[] logger.warning(\"You have likely selected ibv for a single node instance; falling back to MlxRing\") was changed to debug. That will spam this warning since it happens every time we query instance previews.\n[X] In placement_utils.py, get_mlx_jaccl_coordinators, We no longer prioritise Jaccl Coordinator IP. 
Now it picks the first one, which is unstable (Jaccl coordinator over TB5 is unstable).\n"
  },
  {
    "path": "PLATFORMS.md",
    "content": "# EXO Platform support (partial roadmap)\n\n## Tier 1 support - tested and maintained\n\nApple Silicon MacOS\n- Mac Studio: M3 Ultra\n- Mac Mini: M4 Pro\n- Macbook Pro: M5, M4 Max\n\n## Tier 2 support - checked occasionally, should run without crashing\n\n\n## Tier 3 support - minimal support and testing, but no theoretical reason it shouldnt work\n\n\n# Planned\n\n## Tier 1\n\nLinux CUDA Support\n- Nvidia DGX Spark\n\nLinux CPU Support\n\n## Tier 2\n\nLinux Vulkan Support -- depends heavily on ecosystem\n- Framework Desktop\n\nLinux CUDA Support -- depends heavily on ecosystem\n- Framework Desktop\n\n## Longer term!\n\nWindows CUDA Support\n\nWindows CPU Support\n\n"
  },
  {
    "path": "README.md",
    "content": "<div align=\"center\">\n\n<picture>\n  <source media=\"(prefers-color-scheme: light)\" srcset=\"/docs/imgs/exo-logo-black-bg.jpg\">\n  <img alt=\"exo logo\" src=\"/docs/imgs/exo-logo-transparent.png\" width=\"50%\" height=\"50%\">\n</picture>\n\nexo: Run frontier AI locally. Maintained by [exo labs](https://x.com/exolabs).\n\n<p align=\"center\">\n  <a href=\"https://discord.gg/TJ4P57arEm\" target=\"_blank\" rel=\"noopener noreferrer\"><img src=\"https://img.shields.io/badge/Discord-Join%20Server-5865F2?logo=discord&logoColor=white\" alt=\"Discord\"></a>\n  <a href=\"https://x.com/exolabs\" target=\"_blank\" rel=\"noopener noreferrer\"><img src=\"https://img.shields.io/twitter/follow/exolabs?style=social\" alt=\"X\"></a>\n  <a href=\"https://www.apache.org/licenses/LICENSE-2.0.html\" target=\"_blank\" rel=\"noopener noreferrer\"><img src=\"https://img.shields.io/badge/License-Apache2.0-blue.svg\" alt=\"License: Apache-2.0\"></a>\n</p>\n\n</div>\n\n---\n\nexo connects all your devices into an AI cluster. Not only does exo enable running models larger than would fit on a single device, but with [day-0 support for RDMA over Thunderbolt](https://x.com/exolabs/status/2001817749744476256?s=20), makes models run faster as you add more devices.\n\n## Features\n\n- **Automatic Device Discovery**: Devices running exo automatically discover each other - no manual configuration.\n- **RDMA over Thunderbolt**: exo ships with [day-0 support for RDMA over Thunderbolt 5](https://x.com/exolabs/status/2001817749744476256?s=20), enabling 99% reduction in latency between devices.\n- **Topology-Aware Auto Parallel**: exo figures out the best way to split your model across all available devices based on a realtime view of your device topology. It takes into account device resources and network latency/bandwidth between each link.\n- **Tensor Parallelism**: exo supports sharding models, for up to 1.8x speedup on 2 devices and 3.2x speedup on 4 devices.\n- **MLX Support**: exo uses [MLX](https://github.com/ml-explore/mlx) as an inference backend and [MLX distributed](https://ml-explore.github.io/mlx/build/html/usage/distributed.html) for distributed communication.\n- **Multiple API Compatibility**: Compatible with OpenAI Chat Completions API, Claude Messages API, OpenAI Responses API, and Ollama API - use your existing tools and clients.\n- **Custom Model Support**: Load custom models from HuggingFace hub to expand the range of available models.\n\n## Dashboard\n\nexo includes a built-in dashboard for managing your cluster and chatting with models.\n\n<p align=\"center\">\n  <img src=\"docs/imgs/dashboard-cluster-view.png\" alt=\"exo dashboard - cluster view showing 4 x M3 Ultra Mac Studio with DeepSeek v3.1 and Kimi-K2-Thinking loaded\" width=\"80%\" />\n</p>\n<p align=\"center\"><em>4 × 512GB M3 Ultra Mac Studio running DeepSeek v3.1 (8-bit) and Kimi-K2-Thinking (4-bit)</em></p>\n\n## Benchmarks\n\n<details>\n  <summary>Qwen3-235B (8-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA</summary>\n  <img src=\"docs/benchmarks/jeffgeerling/mac-studio-cluster-ai-full-1-qwen3-235b.jpeg\" alt=\"Benchmark - Qwen3-235B (8-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA\" width=\"80%\" />\n  <p>\n    <strong>Source:</strong> <a href=\"https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5\">Jeff Geerling: 15 TB VRAM on Mac Studio – RDMA over Thunderbolt 5</a>\n  </p>\n</details>\n\n<details>\n  <summary>DeepSeek v3.1 671B (8-bit) on 4 × M3 Ultra Mac Studio 
with Tensor Parallel RDMA</summary>\n  <img src=\"docs/benchmarks/jeffgeerling/mac-studio-cluster-ai-full-2-deepseek-3.1-671b.jpeg\" alt=\"Benchmark - DeepSeek v3.1 671B (8-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA\" width=\"80%\" />\n  <p>\n    <strong>Source:</strong> <a href=\"https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5\">Jeff Geerling: 15 TB VRAM on Mac Studio – RDMA over Thunderbolt 5</a>\n  </p>\n</details>\n\n<details>\n  <summary>Kimi K2 Thinking (native 4-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA</summary>\n  <img src=\"docs/benchmarks/jeffgeerling/mac-studio-cluster-ai-full-3-kimi-k2-thinking.jpeg\" alt=\"Benchmark - Kimi K2 Thinking (native 4-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA\" width=\"80%\" />\n  <p>\n    <strong>Source:</strong> <a href=\"https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5\">Jeff Geerling: 15 TB VRAM on Mac Studio – RDMA over Thunderbolt 5</a>\n  </p>\n</details>\n\n---\n\n## Quick Start\n\nDevices running exo automatically discover each other, without needing any manual configuration. Each device provides an API and a dashboard for interacting with your cluster (runs at `http://localhost:52415`).\n\nThere are two ways to run exo:\n\n### Run from Source (macOS)\n\nIf you have [Nix](https://nixos.org/) installed, you can skip most of the steps below and run exo directly:\n\n```bash\nnix run .#exo\n```\n\n**Note:** To accept the Cachix binary cache (and avoid needing the Xcode Metal ToolChain), add to `/etc/nix/nix.conf`:\n```\ntrusted-users = root    (or your username)\nexperimental-features = nix-command flakes\n```\nThen restart the Nix daemon: `sudo launchctl kickstart -k system/org.nixos.nix-daemon`\n\n**Prerequisites:**\n- [Xcode](https://developer.apple.com/xcode/) (provides the Metal ToolChain required for MLX compilation)\n- [brew](https://github.com/Homebrew/brew) (for simple package management on macOS)\n\n  ```bash\n  /bin/bash -c \"$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)\"\n  ```\n- [uv](https://github.com/astral-sh/uv) (for Python dependency management)\n- [macmon](https://github.com/vladkens/macmon) (for hardware monitoring on Apple Silicon)\n- [node](https://github.com/nodejs/node) (for building the dashboard)\n\n  ```bash\n  brew install uv macmon node\n  ```\n- [rust](https://github.com/rust-lang/rustup) (to build Rust bindings, nightly for now)\n\n  ```bash\n  curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh\n  rustup toolchain install nightly\n  ```\n\nClone the repo, build the dashboard, and run exo:\n\n```bash\n# Clone exo\ngit clone https://github.com/exo-explore/exo\n\n# Build dashboard\ncd exo/dashboard && npm install && npm run build && cd ..\n\n# Run exo\nuv run exo\n```\n\nThis starts the exo dashboard and API at http://localhost:52415/\n\n\n*Please see the \"Enabling RDMA on macOS\" section below to enable this feature on macOS 26.2 or later!*\n\n\n### Run from Source (Linux)\n\n**Prerequisites:**\n\n- [uv](https://github.com/astral-sh/uv) (for Python dependency management)\n- [node](https://github.com/nodejs/node) (for building the dashboard) - version 18 or higher\n- [rust](https://github.com/rust-lang/rustup) (to build Rust bindings, nightly for now)\n\n**Installation methods:**\n\n**Option 1: Using system package manager (Ubuntu/Debian example):**\n```bash\n# Install Node.js and npm\nsudo apt update\nsudo apt install nodejs npm\n\n# Install uv\ncurl -LsSf 
https://astral.sh/uv/install.sh | sh\n\n# Install Rust (using rustup)\ncurl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh\nrustup toolchain install nightly\n```\n\n**Option 2: Using Homebrew on Linux (if preferred):**\n```bash\n# Install Homebrew on Linux\n/bin/bash -c \"$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)\"\n\n# Install dependencies\nbrew install uv node\n\n# Install Rust (using rustup)\ncurl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh\nrustup toolchain install nightly\n```\n\n**Note:** The `macmon` package is macOS-only and not required for Linux.\n\nClone the repo, build the dashboard, and run exo:\n\n```bash\n# Clone exo\ngit clone https://github.com/exo-explore/exo\n\n# Build dashboard\ncd exo/dashboard && npm install && npm run build && cd ..\n\n# Run exo\nuv run exo\n```\n\nThis starts the exo dashboard and API at http://localhost:52415/\n\n**Important note for Linux users:** Currently, exo runs on CPU on Linux. GPU support for Linux platforms is under development. If you'd like to see support for your specific Linux hardware, please [search for existing feature requests](https://github.com/exo-explore/exo/issues) or create a new one.\n\n**Configuration Options:**\n\n- `--no-worker`: Run exo without the worker component. Useful for coordinator-only nodes that handle networking and orchestration but don't execute inference tasks. This is helpful for machines without sufficient GPU resources but with good network connectivity.\n\n  ```bash\n  uv run exo --no-worker\n  ```\n\n**File Locations (Linux):**\n\nexo follows the [XDG Base Directory Specification](https://specifications.freedesktop.org/basedir-spec/basedir-spec-latest.html) on Linux:\n\n- **Configuration files**: `~/.config/exo/` (or `$XDG_CONFIG_HOME/exo/`)\n- **Data files**: `~/.local/share/exo/` (or `$XDG_DATA_HOME/exo/`)\n- **Cache files**: `~/.cache/exo/` (or `$XDG_CACHE_HOME/exo/`)\n- **Log files**: `~/.cache/exo/exo_log/` (with automatic log rotation)\n- **Custom model cards**: `~/.local/share/exo/custom_model_cards/`\n\nYou can override these locations by setting the corresponding XDG environment variables.\n\n### macOS App\n\nexo ships a macOS app that runs in the background on your Mac.\n\n<img src=\"docs/imgs/macos-app-one-macbook.png\" alt=\"exo macOS App - running on a MacBook\" width=\"35%\" />\n\nThe macOS app requires macOS Tahoe 26.2 or later.\n\nDownload the latest build here: [EXO-latest.dmg](https://assets.exolabs.net/EXO-latest.dmg).\n\nThe app will ask for permission to modify system settings and install a new Network profile. Improvements to this are being worked on.\n\n**Custom Namespace for Cluster Isolation:**\n\nThe macOS app includes a custom namespace feature that allows you to isolate your exo cluster from others on the same network. This is configured through the `EXO_LIBP2P_NAMESPACE` setting:\n\n- **Use cases**:\n  - Running multiple separate exo clusters on the same network\n  - Isolating development/testing clusters from production clusters\n  - Preventing accidental cluster joining\n\n- **Configuration**: Access this setting in the app's Advanced settings (or set the `EXO_LIBP2P_NAMESPACE` environment variable when running from source)\n\nThe namespace is logged on startup for debugging purposes.\n\n#### Uninstalling the macOS App\n\nThe recommended way to uninstall is through the app itself: click the menu bar icon → Advanced → Uninstall. 
This cleanly removes all system components.\n\nIf you've already deleted the app, you can run the standalone uninstaller script:\n\n```bash\nsudo ./app/EXO/uninstall-exo.sh\n```\n\nThis removes:\n- Network setup LaunchDaemon\n- Network configuration script\n- Log files\n- The \"exo\" network location\n\n**Note:** You'll need to manually remove EXO from Login Items in System Settings → General → Login Items.\n\n---\n\n### Enabling RDMA on macOS\n\nRDMA is a new capability added to macOS 26.2. It works on any Mac with Thunderbolt 5 (M4 Pro Mac Mini, M4 Max Mac Studio, M4 Max MacBook Pro, M3 Ultra Mac Studio).\n\nPlease refer to the caveats for immediate troubleshooting.\n\nTo enable RDMA on macOS, follow these steps:\n\n1. Shut down your Mac.\n2. Hold down the power button for 10 seconds until the boot menu appears.\n3. Select \"Options\" to enter Recovery mode.\n4. When the Recovery UI appears, open the Terminal from the Utilities menu.\n5. In the Terminal, type:\n   ```\n   rdma_ctl enable\n   ```\n   and press Enter.\n6. Reboot your Mac.\n\nAfter that, RDMA will be enabled in macOS and exo will take care of the rest.\n\n**Important Caveats**\n\n1. Devices that wish to be part of an RDMA cluster must be connected to all other devices in the cluster.\n2. The cables must support TB5.\n3. On a Mac Studio, you cannot use the Thunderbolt 5 port next to the Ethernet port.\n4. If running from source, please use the script found at `tmp/set_rdma_network_config.sh`, which will disable Thunderbolt Bridge and set DHCP on each RDMA port.\n5. RDMA ports may be unable to discover each other on different versions of macOS. Please ensure that OS versions match exactly (even beta version numbers) on all devices.\n\n---\n\n## Environment Variables\n\nexo supports several environment variables for configuration:\n\n| Variable | Description | Default |\n|----------|-------------|---------|\n| `EXO_MODELS_PATH` | Colon-separated paths to search for pre-downloaded models (e.g., on NFS mounts or shared storage) | None |\n| `EXO_MODELS_DIR` | Directory where exo downloads and stores models | `~/.local/share/exo/models` (Linux) or `~/.exo/models` (macOS) |\n| `EXO_OFFLINE` | Run without internet connection (uses only local models) | `false` |\n| `EXO_ENABLE_IMAGE_MODELS` | Enable image model support | `false` |\n| `EXO_LIBP2P_NAMESPACE` | Custom namespace for cluster isolation | None |\n| `EXO_FAST_SYNCH` | Control MLX_METAL_FAST_SYNCH behavior (for JACCL backend) | Auto |\n| `EXO_TRACING_ENABLED` | Enable distributed tracing for performance analysis | `false` |\n\n**Example usage:**\n\n```bash\n# Use pre-downloaded models from NFS mount\nEXO_MODELS_PATH=/mnt/nfs/models:/opt/ai-models uv run exo\n\n# Run in offline mode\nEXO_OFFLINE=true uv run exo\n\n# Enable image models\nEXO_ENABLE_IMAGE_MODELS=true uv run exo\n\n# Use custom namespace for cluster isolation\nEXO_LIBP2P_NAMESPACE=my-dev-cluster uv run exo\n```\n\n---\n\n### Using the API\n\nexo provides multiple API-compatible interfaces for maximum compatibility with existing tools:\n\n- **OpenAI Chat Completions API** - Compatible with OpenAI clients\n- **Claude Messages API** - Compatible with Anthropic's Claude format\n- **OpenAI Responses API** - Compatible with OpenAI's Responses format\n- **Ollama API** - Compatible with Ollama and tools like OpenWebUI\n\nIf you prefer to interact with exo via the API, here is an example creating an instance of a small model (`mlx-community/Llama-3.2-1B-Instruct-4bit`), sending a chat completions request, and deleting the 
instance.\n\n---\n\n**1. Preview instance placements**\n\nThe `/instance/previews` endpoint will preview all valid placements for your model.\n\n```bash\ncurl \"http://localhost:52415/instance/previews?model_id=llama-3.2-1b\"\n```\n\nSample response:\n\n```json\n{\n  \"previews\": [\n    {\n      \"model_id\": \"mlx-community/Llama-3.2-1B-Instruct-4bit\",\n      \"sharding\": \"Pipeline\",\n      \"instance_meta\": \"MlxRing\",\n      \"instance\": {...},\n      \"memory_delta_by_node\": {\"local\": 729808896},\n      \"error\": null\n    }\n    // ...possibly more placements...\n  ]\n}\n```\n\nThis will return all valid placements for this model. Pick a placement that you like.\nTo pick the first one, pipe into `jq`:\n\n```bash\ncurl \"http://localhost:52415/instance/previews?model_id=llama-3.2-1b\" | jq -c '.previews[] | select(.error == null) | .instance' | head -n1\n```\n\n---\n\n**2. Create a model instance**\n\nSend a POST to `/instance` with your desired placement in the `instance` field (the full payload must match types as in `CreateInstanceParams`), which you can copy from step 1:\n\n```bash\ncurl -X POST http://localhost:52415/instance \\\n  -H 'Content-Type: application/json' \\\n  -d '{\n    \"instance\": {...}\n  }'\n```\n\n\nSample response:\n\n```json\n{\n  \"message\": \"Command received.\",\n  \"command_id\": \"e9d1a8ab-....\"\n}\n```\n\n---\n\n**3. Send a chat completion**\n\nNow, make a POST to `/v1/chat/completions` (the same format as OpenAI's API):\n\n```bash\ncurl -N -X POST http://localhost:52415/v1/chat/completions \\\n  -H 'Content-Type: application/json' \\\n  -d '{\n    \"model\": \"mlx-community/Llama-3.2-1B-Instruct-4bit\",\n    \"messages\": [\n      {\"role\": \"user\", \"content\": \"What is Llama 3.2 1B?\"}\n    ],\n    \"stream\": true\n  }'\n```\n\n---\n\n**4. 
Delete the instance**\n\nWhen you're done, delete the instance by its ID (find it via `/state` or `/instance` endpoints):\n\n```bash\ncurl -X DELETE http://localhost:52415/instance/YOUR_INSTANCE_ID\n```\n\n### Claude Messages API Compatibility\n\nUse the Claude Messages API format with the `/v1/messages` endpoint:\n\n```bash\ncurl -N -X POST http://localhost:52415/v1/messages \\\n  -H 'Content-Type: application/json' \\\n  -d '{\n    \"model\": \"mlx-community/Llama-3.2-1B-Instruct-4bit\",\n    \"messages\": [\n      {\"role\": \"user\", \"content\": \"Hello\"}\n    ],\n    \"max_tokens\": 1024,\n    \"stream\": true\n  }'\n```\n\n### OpenAI Responses API Compatibility\n\nUse the OpenAI Responses API format with the `/v1/responses` endpoint:\n\n```bash\ncurl -N -X POST http://localhost:52415/v1/responses \\\n  -H 'Content-Type: application/json' \\\n  -d '{\n    \"model\": \"mlx-community/Llama-3.2-1B-Instruct-4bit\",\n    \"messages\": [\n      {\"role\": \"user\", \"content\": \"Hello\"}\n    ],\n    \"stream\": true\n  }'\n```\n\n### Ollama API Compatibility\n\nexo supports Ollama API endpoints for compatibility with tools like OpenWebUI:\n\n```bash\n# Ollama chat\ncurl -X POST http://localhost:52415/ollama/api/chat \\\n  -H 'Content-Type: application/json' \\\n  -d '{\n    \"model\": \"mlx-community/Llama-3.2-1B-Instruct-4bit\",\n    \"messages\": [\n      {\"role\": \"user\", \"content\": \"Hello\"}\n    ],\n    \"stream\": false\n  }'\n\n# List models (Ollama format)\ncurl http://localhost:52415/ollama/api/tags\n```\n\n### Custom Model Loading from HuggingFace\n\nYou can add custom models from the HuggingFace hub:\n\n```bash\ncurl -X POST http://localhost:52415/models/add \\\n  -H 'Content-Type: application/json' \\\n  -d '{\n    \"model_id\": \"mlx-community/my-custom-model\"\n  }'\n```\n\n**Security Note:**\n\nFor security, custom models requiring `trust_remote_code` in their configuration must be explicitly enabled (the default is `false`). Only enable this if you trust the model's remote code execution. Models are fetched from HuggingFace and stored locally as custom model cards.\n\n**Other useful API endpoints:**\n\n- List all models: `curl http://localhost:52415/models`\n- List downloaded models only: `curl http://localhost:52415/models?status=downloaded`\n- Search HuggingFace: `curl \"http://localhost:52415/models/search?query=llama&limit=10\"`\n- Inspect instance IDs and deployment state: `curl http://localhost:52415/state`\n\nFor further details, see:\n\n- API documentation in [docs/api.md](docs/api.md).\n- API types and endpoints in [src/exo/master/api.py](src/exo/master/api.py).\n\n---\n\n## Benchmarking\n\nThe `exo-bench` tool measures model prefill and token generation speed across different placement configurations. 
This helps you optimize model performance and validate improvements.\n\n**Prerequisites:**\n- Nodes should be running with `uv run exo` before benchmarking\n- The tool uses the `/bench/chat/completions` endpoint\n\n**Basic usage:**\n\n```bash\nuv run bench/exo_bench.py \\\n  --model Llama-3.2-1B-Instruct-4bit \\\n  --pp 128,256,512 \\\n  --tg 128,256\n```\n\n**Key parameters:**\n\n- `--model`: Model to benchmark (short ID or HuggingFace ID)\n- `--pp`: Prompt size hints (comma-separated integers)\n- `--tg`: Generation lengths (comma-separated integers)\n- `--max-nodes`: Limit placements to N nodes (default: 4)\n- `--instance-meta`: Filter by `ring`, `jaccl`, or `both` (default: both)\n- `--sharding`: Filter by `pipeline`, `tensor`, or `both` (default: both)\n- `--repeat`: Number of repetitions per configuration (default: 1)\n- `--warmup`: Warmup runs per placement (default: 0)\n- `--json-out`: Output file for results (default: bench/results.json)\n\n**Example with filters:**\n\n```bash\nuv run bench/exo_bench.py \\\n  --model Llama-3.2-1B-Instruct-4bit \\\n  --pp 128,512 \\\n  --tg 128 \\\n  --max-nodes 2 \\\n  --sharding tensor \\\n  --repeat 3 \\\n  --json-out my-results.json\n```\n\nThe tool outputs performance metrics including prompt tokens per second (prompt_tps), generation tokens per second (generation_tps), and peak memory usage for each configuration.\n\n---\n\n## Hardware Accelerator Support\n\nOn macOS, exo uses the GPU. On Linux, exo currently runs on CPU. We are working on extending hardware accelerator support. If you'd like support for a new hardware platform, please [search for an existing feature request](https://github.com/exo-explore/exo/issues) and add a thumbs up so we know what hardware is important to the community.\n\n---\n\n## Contributing\n\nSee [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on how to contribute to exo.\n"
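## Appendix: Calling the API from Python

The `curl` examples in "Using the API" above also translate directly to scripts. A minimal sketch using only the Python standard library, assuming the standard Chat Completions response shape and a running instance of the example model:

```python
import json
import urllib.request

# Non-streaming chat completion against exo's OpenAI-compatible endpoint.
request_body = json.dumps({
    "model": "mlx-community/Llama-3.2-1B-Instruct-4bit",
    "messages": [{"role": "user", "content": "What is Llama 3.2 1B?"}],
    "stream": False,
}).encode("utf-8")

request = urllib.request.Request(
    "http://localhost:52415/v1/chat/completions",
    data=request_body,
    headers={"Content-Type": "application/json"},
)

with urllib.request.urlopen(request) as response:
    reply = json.loads(response.read())

# Standard Chat Completions shape: first choice, assistant message content.
print(reply["choices"][0]["message"]["content"])
```

Because the endpoint is OpenAI-compatible, existing OpenAI client libraries should also work when pointed at the `http://localhost:52415/v1` base URL.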
  },
  {
    "path": "RULES.md",
    "content": "# Repository Rules\n\n* if you see any code that violates these rules, raise it with me directly rather than trying to fix.\n  * where applicable, file a GitHub Issue.\n* adhere to these rules strictly.\n\n## General Rules\n\n* if its possible to eliminate an extra try-catch or if-statement at runtime using type-level discipline, do it!\n* name your types, functions, and classes appropriately.\n  * no three-letter acronyms.\n  * no non-standard contractions.\n  * each data type has a meaning, pick a name which is accurate and descriptive.\n  * the average layman should be able to easily understand what your function does using the function signature alone!\n    * sometimes, there will be exceptions. eg, when you're using specific technical terms that are well understood (saga, event, etc).\n    * usually, you'll think that your code is an exception to the rules, but it won't be.\n\n## State, Functions and Classes\n\n* every function, given the same inputs, should produce the same outputs. ie, no hidden state.\n* use classes to prevent fixed state from being mutated arbitrarily (unsafely); methods provide a safe way of interfacing with state.\n* if your logic doesn't mutate fixed state, it probably belongs in a standalone function rather than a class.\n* functions shouldn't usually produce side-effects (they should be computationally pure).\n  * if, for example, you're updating a state using an event (computationally pure), and you want to trigger a saga (computational side-effect), store the logic for triggering the saga into an effect handler (a function, capable of producing side-effects, that you pass into an otherwise computationally pure function, so that it may trigger side-effects safely).\n\n## Pydantic\n\n* read the Pydantic docs.\n* respect the Pydantic docs.\n* pydantic is all you need.\n* declare and re-use a central `ConfigDict` for your use-case, you'll usually want `frozen` and `strict` to be `True`.\n\n## Unique ID (UUID) Generation\n\n* inherit from Pydantic's `UUID4` class to create your own UUID class.\n* use `uuid.uuid4()` to initialize your class with a fresh UUID where possible.\n* ensure that idempotency tags are generated by taking the salted hash of persisted state.\n  * rationale: if a node crashes and resumes from an older state, it should not accidentally re-publish the same event twice under different idempotency tags. 
\n  * every distinct function should feature a unique salt, so that there are no accidental collisions in idempotency tags.\n\n## Type Wrappers\n\n* reuse types that already exist in the Python standard library.\n* when two distinct data types are structurally identical (for example, different IDs which are both UUIDs but shouldn't ever be mixed up), make sure they can't be conflated by the type system.\n  * if you're working with a primitive data type (`str`, `int`, etc), use `NewType` (it has zero runtime overhead).\n  * if you're working with serializable data objects, consider adding a field (type `str`) that states its type.\n\n## Type Discipline\n\n* do not bypass the type-checker, preserve strict typing by any means necessary.\n* by default, use literal types (like `Literal['one', 'two']`) where an enum seems appropriate.\n\npro-tip: Python's type system is quite complex and feature-rich, so reading the documentation is often advisable; Matt discovered that Python's `typing` library allows you to check that you've implemented a `match` exhaustively using `Literal` and `get_args(type)` after reading the docs (see the appendix below for a sketch).\n\n## Use of `@final`, Freezing\n\n* use wherever applicable.\n\n## Error Handling\n\n* don't try-catch for no reason.\n* make sure that you always know where and when the exceptions your code produces are meant to be handled, so that it's never a nasty surprise.\n  * always write the rationale for your error-handling down in the docstring!\n  * communicate the details to your colleagues when appropriate.\n\n## Dependencies\n\n* don't introduce any new dependencies without asking.\n* don't ask for any dependencies that aren't ubiquitous within production environments.\n\n## Commit Messages\n\n*   use the imperative mood in the subject line.\n*   prefix the subject line with a change type. our change types are:\n    *   `documentation`: documentation changes.\n    *   `feature`: a new feature.\n    *   `refactor`: a code change that neither fixes a bug nor adds a feature.\n    *   `bugfix`: a bug fix.\n    *   `chore`: routine tasks, maintenance, or tooling changes.\n    *   `test`: adding or correcting tests.\n*   restrict the subject line to fifty characters or less.\n*   capitalize the subject line.\n*   do not end the subject line with a period.\n*   separate subject from body with a blank line."
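## Appendix: Exhaustive Match Example

a minimal sketch of the pro-tip from Type Discipline (names are illustrative, not from this repository):

```python
from typing import Literal, assert_never, get_args

ConnectionKind = Literal["thunderbolt", "ethernet", "wifi"]


def connection_priority(kind: ConnectionKind) -> int:
    # the match is exhaustive over the Literal; adding a new member to
    # ConnectionKind without a case here fails type-checking via assert_never.
    match kind:
        case "thunderbolt":
            return 0
        case "ethernet":
            return 1
        case "wifi":
            return 2
        case _:
            assert_never(kind)


# get_args recovers the Literal members at runtime, e.g. for a test that
# every member is actually handled:
for kind in get_args(ConnectionKind):
    assert connection_priority(kind) >= 0
```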
  },
  {
    "path": "TODO.md",
    "content": "3. Task cancellation. When API http request gets cancelled, it should cancel corresponding task.\n4. I'd like to see profiled network latency / bandwidth.\n5. I'd like to see how much bandwidth each link is using.\n7. Solve the problem of in continuous batching when a new prompt comes in, it will block decode of the current batch until the prefill is complete.\n8. We want people to be able to copy models over to a new device without ever connecting EXO to the internet. Right now EXO require internet connection once to cache some files to check if a download is complete. Instead, we should simply check if there is a non-empty model folder locally with no .partial files. This indicates it's a fully downloaded model that can be loaded.\n13. Memory pressure instead of memory used.\n14. Show the type of each connection (TB5, Ethernet, etc.) in the UI. Refer to old exo: https://github.com/exo-explore/exo/blob/56f783b38dc6b08ce606b07a5386dc40dae00330/exo/helpers.py#L251\n15. Prioritise certain connection types (or by latency). TB5 > Ethernet > WiFi. Refer to old exo: https://github.com/exo-explore/exo/blob/56f783b38dc6b08ce606b07a5386dc40dae00330/exo/helpers.py#L251\n16. Dynamically switch to higher priority connection when it becomes available. Probably bring back InstanceReplacedAtomically.\n17. Faster model loads by streaming model from other devices in cluster.\n18. Add support for specifying the type of network connection to use in a test. Depends on 15/16.\n25. Rethink retry logic\n27. Log cleanup - per-module log filters and default to DEBUG log levels\n28. Validate RDMA connections with ibv_devinfo in the info gatherer\n"
  },
  {
    "path": "app/EXO/EXO/Assets.xcassets/AccentColor.colorset/Contents.json",
    "content": "{\n  \"colors\" : [\n    {\n      \"idiom\" : \"universal\"\n    }\n  ],\n  \"info\" : {\n    \"author\" : \"xcode\",\n    \"version\" : 1\n  }\n}\n"
  },
  {
    "path": "app/EXO/EXO/Assets.xcassets/AppIcon.appiconset/Contents.json",
    "content": "{\n  \"images\" : [\n    {\n      \"filename\" : \"16-mac.png\",\n      \"idiom\" : \"mac\",\n      \"scale\" : \"1x\",\n      \"size\" : \"16x16\"\n    },\n    {\n      \"filename\" : \"32-mac.png\",\n      \"idiom\" : \"mac\",\n      \"scale\" : \"2x\",\n      \"size\" : \"16x16\"\n    },\n    {\n      \"filename\" : \"32-mac 1.png\",\n      \"idiom\" : \"mac\",\n      \"scale\" : \"1x\",\n      \"size\" : \"32x32\"\n    },\n    {\n      \"filename\" : \"64-mac.png\",\n      \"idiom\" : \"mac\",\n      \"scale\" : \"2x\",\n      \"size\" : \"32x32\"\n    },\n    {\n      \"filename\" : \"128-mac.png\",\n      \"idiom\" : \"mac\",\n      \"scale\" : \"1x\",\n      \"size\" : \"128x128\"\n    },\n    {\n      \"filename\" : \"256-mac 1.png\",\n      \"idiom\" : \"mac\",\n      \"scale\" : \"2x\",\n      \"size\" : \"128x128\"\n    },\n    {\n      \"filename\" : \"256-mac.png\",\n      \"idiom\" : \"mac\",\n      \"scale\" : \"1x\",\n      \"size\" : \"256x256\"\n    },\n    {\n      \"filename\" : \"512-mac.png\",\n      \"idiom\" : \"mac\",\n      \"scale\" : \"2x\",\n      \"size\" : \"256x256\"\n    },\n    {\n      \"filename\" : \"512-mac 1.png\",\n      \"idiom\" : \"mac\",\n      \"scale\" : \"1x\",\n      \"size\" : \"512x512\"\n    },\n    {\n      \"filename\" : \"1024-mac.png\",\n      \"idiom\" : \"mac\",\n      \"scale\" : \"2x\",\n      \"size\" : \"512x512\"\n    }\n  ],\n  \"info\" : {\n    \"author\" : \"xcode\",\n    \"version\" : 1\n  }\n}\n"
  },
  {
    "path": "app/EXO/EXO/Assets.xcassets/Contents.json",
    "content": "{\n  \"info\" : {\n    \"author\" : \"xcode\",\n    \"version\" : 1\n  }\n}\n"
  },
  {
    "path": "app/EXO/EXO/Assets.xcassets/menubar-icon.imageset/Contents.json",
    "content": "{\n  \"images\" : [\n    {\n      \"filename\" : \"exo-logo-hq-square-transparent-bg.png\",\n      \"idiom\" : \"universal\",\n      \"scale\" : \"1x\"\n    },\n    {\n      \"idiom\" : \"universal\",\n      \"scale\" : \"2x\"\n    },\n    {\n      \"idiom\" : \"universal\",\n      \"scale\" : \"3x\"\n    }\n  ],\n  \"info\" : {\n    \"author\" : \"xcode\",\n    \"version\" : 1\n  }\n}\n"
  },
  {
    "path": "app/EXO/EXO/ContentView.swift",
    "content": "//\n//  ContentView.swift\n//  EXO\n//\n//  Created by Sami Khan on 2025-11-22.\n//\n\nimport AppKit\nimport SwiftUI\n\nstruct ContentView: View {\n    @EnvironmentObject private var controller: ExoProcessController\n    @EnvironmentObject private var stateService: ClusterStateService\n    @EnvironmentObject private var networkStatusService: NetworkStatusService\n    @EnvironmentObject private var localNetworkChecker: LocalNetworkChecker\n    @EnvironmentObject private var updater: SparkleUpdater\n    @EnvironmentObject private var thunderboltBridgeService: ThunderboltBridgeService\n    @EnvironmentObject private var settingsWindowController: SettingsWindowController\n    @State private var focusedNode: NodeViewModel?\n    @State private var deletingInstanceIDs: Set<String> = []\n    @State private var showAllNodes = false\n    @State private var showAllInstances = false\n    @State private var baseURLCopied = false\n    @State private var showAdvanced = false\n    @State private var showDebugInfo = false\n    private enum BugReportPhase: Equatable {\n        case idle\n        case prompting\n        case sending(String)\n        case success(String)\n        case failure(String)\n    }\n    @State private var bugReportPhase: BugReportPhase = .idle\n    @State private var bugReportUserDescription: String = \"\"\n    @State private var uninstallInProgress = false\n    @State private var pendingNamespace: String = \"\"\n    @State private var pendingHFToken: String = \"\"\n    @State private var pendingEnableImageModels = false\n\n    var body: some View {\n        VStack(alignment: .leading, spacing: 12) {\n            statusSection\n            if shouldShowLocalNetworkWarning {\n                localNetworkWarningBanner\n            }\n            if shouldShowClusterDetails {\n                Divider()\n                overviewSection\n                topologySection\n                nodeSection\n            }\n            if shouldShowInstances {\n                instanceSection\n            }\n            Spacer(minLength: 0)\n            controlButtons\n        }\n        .animation(.easeInOut(duration: 0.3), value: shouldShowClusterDetails)\n        .animation(.easeInOut(duration: 0.3), value: shouldShowInstances)\n        .animation(.easeInOut(duration: 0.3), value: shouldShowLocalNetworkWarning)\n        .padding()\n        .frame(width: 340)\n        .onAppear {\n            Task {\n                await networkStatusService.refresh()\n            }\n        }\n    }\n\n    private var shouldShowLocalNetworkWarning: Bool {\n        // Show warning if local network is not working and EXO is running.\n        // The checker uses a longer timeout on first launch to allow time for\n        // the permission prompt, so this correctly handles both:\n        // 1. User denied permission on first launch\n        // 2. 
Permission broke after restart (macOS TCC bug)\n        if case .notWorking = localNetworkChecker.status {\n            return controller.status != .stopped\n        }\n        return false\n    }\n\n    private var localNetworkWarningBanner: some View {\n        VStack(alignment: .leading, spacing: 6) {\n            HStack(spacing: 6) {\n                Image(systemName: \"exclamationmark.triangle.fill\")\n                    .foregroundColor(.orange)\n                Text(\"Local Network Access Issue\")\n                    .font(.caption)\n                    .fontWeight(.semibold)\n            }\n            Text(\n                \"Device discovery won't work. To fix:\\n1. Quit EXO\\n2. Open System Settings → Privacy & Security → Local Network\\n3. Toggle EXO off, then back on\\n4. Relaunch EXO\"\n            )\n            .font(.caption2)\n            .foregroundColor(.secondary)\n            .fixedSize(horizontal: false, vertical: true)\n            Button {\n                openLocalNetworkSettings()\n            } label: {\n                Text(\"Open Settings\")\n                    .font(.caption2)\n            }\n            .buttonStyle(.bordered)\n            .controlSize(.small)\n        }\n        .padding(8)\n        .background(\n            RoundedRectangle(cornerRadius: 8)\n                .fill(Color.orange.opacity(0.1))\n        )\n        .overlay(\n            RoundedRectangle(cornerRadius: 8)\n                .stroke(Color.orange.opacity(0.3), lineWidth: 1)\n        )\n    }\n\n    private func openLocalNetworkSettings() {\n        // Open Privacy & Security settings - Local Network section\n        if let url = URL(\n            string: \"x-apple.systempreferences:com.apple.preference.security?Privacy_LocalNetwork\")\n        {\n            NSWorkspace.shared.open(url)\n        }\n    }\n\n    private var topologySection: some View {\n        Group {\n            if let topology = stateService.latestSnapshot?.topologyViewModel(\n                localNodeId: stateService.localNodeId), !topology.nodes.isEmpty\n            {\n                TopologyMiniView(topology: topology)\n            }\n        }\n    }\n\n    private var statusSection: some View {\n        HStack(spacing: 8) {\n            VStack(alignment: .leading, spacing: 2) {\n                Text(\"EXO\")\n                    .font(.headline)\n                Text(controller.status.displayText)\n                    .font(.caption)\n                    .foregroundColor(.secondary)\n                if let detail = statusDetailText {\n                    Text(detail)\n                        .font(.caption2)\n                        .foregroundColor(.secondary)\n                }\n            }\n            Spacer()\n            Toggle(\"\", isOn: processToggleBinding)\n                .toggleStyle(.switch)\n                .labelsHidden()\n        }\n    }\n\n    private var overviewSection: some View {\n        Group {\n            if let snapshot = stateService.latestSnapshot {\n                let overview = snapshot.overview()\n                VStack(alignment: .leading, spacing: 4) {\n                    HStack {\n                        VStack(alignment: .leading) {\n                            Text(\n                                \"\\(overview.usedRam, specifier: \"%.0f\") / \\(overview.totalRam, specifier: \"%.0f\") GB\"\n                            )\n                            .font(.headline)\n                            Text(\"Memory\")\n                                .font(.caption)\n    
                            .foregroundColor(.secondary)\n                        }\n                        Spacer()\n                        VStack(alignment: .leading) {\n                            Text(\"\\(overview.nodeCount)\")\n                                .font(.headline)\n                            Text(\"Nodes\")\n                                .font(.caption)\n                                .foregroundColor(.secondary)\n                        }\n                        Spacer()\n                        VStack(alignment: .leading) {\n                            Text(\"\\(overview.instanceCount)\")\n                                .font(.headline)\n                            Text(\"Instances\")\n                                .font(.caption)\n                                .foregroundColor(.secondary)\n                        }\n                    }\n                }\n            } else {\n                Text(\"Connecting to EXO…\")\n                    .font(.caption)\n                    .foregroundColor(.secondary)\n            }\n        }\n    }\n\n    private var nodeSection: some View {\n        Group {\n            if let nodes = stateService.latestSnapshot?.nodeViewModels(), !nodes.isEmpty {\n                VStack(alignment: .leading, spacing: 4) {\n                    HStack {\n                        Text(\"Nodes\")\n                            .font(.caption)\n                            .foregroundColor(.secondary)\n                        Text(\"(\\(nodes.count))\")\n                            .font(.caption)\n                            .foregroundColor(.secondary)\n                        Spacer()\n                        collapseButton(isExpanded: $showAllNodes)\n                    }\n                    .animation(nil, value: showAllNodes)\n                    if showAllNodes {\n                        VStack(alignment: .leading, spacing: 8) {\n                            ForEach(nodes) { node in\n                                NodeRowView(node: node)\n                                    .padding(.horizontal, 6)\n                                    .padding(.vertical, 4)\n                                    .background(.regularMaterial.opacity(0.6))\n                                    .clipShape(RoundedRectangle(cornerRadius: 6))\n                            }\n                        }\n                        .transition(.opacity)\n                    }\n                }\n                .animation(.easeInOut(duration: 0.25), value: showAllNodes)\n            }\n        }\n    }\n\n    private var instanceSection: some View {\n        Group {\n            if let instances = stateService.latestSnapshot?.instanceViewModels() {\n                VStack(alignment: .leading, spacing: 8) {\n                    HStack {\n                        Text(\"Instances\")\n                            .font(.caption)\n                            .foregroundColor(.secondary)\n                        Text(\"(\\(instances.count))\")\n                            .font(.caption)\n                            .foregroundColor(.secondary)\n                        Spacer()\n                        if !instances.isEmpty {\n                            collapseButton(isExpanded: $showAllInstances)\n                        }\n                    }\n                    .animation(nil, value: showAllInstances)\n                    if showAllInstances, !instances.isEmpty {\n                        VStack(alignment: .leading, spacing: 8) {\n                            
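// One card-styled row per instance, matching the node list styling above.\n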
ForEach(instances) { instance in\n                                InstanceRowView(instance: instance)\n                                    .padding(.horizontal, 6)\n                                    .padding(.vertical, 4)\n                                    .background(.regularMaterial.opacity(0.6))\n                                    .clipShape(RoundedRectangle(cornerRadius: 6))\n                            }\n                        }\n                        .transition(.opacity)\n                    }\n                }\n                .animation(.easeInOut(duration: 0.25), value: showAllInstances)\n            }\n        }\n    }\n\n    private var controlButtons: some View {\n        VStack(alignment: .leading, spacing: 0) {\n            if controller.status != .stopped {\n                dashboardButton\n                baseURLRow\n                Divider()\n                    .padding(.vertical, 8)\n            } else {\n                Divider()\n                    .padding(.vertical, 4)\n            }\n            HoverButton(\n                title: \"Settings\",\n                tint: .primary,\n                trailingSystemImage: \"gear\"\n            ) {\n                settingsWindowController.open(\n                    controller: controller,\n                    updater: updater,\n                    networkStatusService: networkStatusService,\n                    thunderboltBridgeService: thunderboltBridgeService,\n                    stateService: stateService\n                )\n            }\n            HoverButton(\n                title: \"Check for Updates\",\n                tint: .primary,\n                trailingSystemImage: \"arrow.triangle.2.circlepath\"\n            ) {\n                updater.checkForUpdates()\n            }\n            .padding(.bottom, 8)\n            HoverButton(title: \"Quit\", tint: .secondary) {\n                controller.stop()\n                NSApplication.shared.terminate(nil)\n            }\n        }\n    }\n\n    private var dashboardButton: some View {\n        HoverButton(\n            title: \"Web Dashboard\",\n            tint: .primary,\n            trailingSystemImage: \"arrow.up.right\"\n        ) {\n            guard let url = URL(string: \"http://localhost:52415/\") else { return }\n            NSWorkspace.shared.open(url)\n        }\n    }\n\n    private var baseURLRow: some View {\n        HStack(spacing: 6) {\n            Image(systemName: \"link\")\n                .imageScale(.small)\n                .foregroundColor(.secondary)\n            Text(\"localhost:52415/v1\")\n                .font(.system(.caption, design: .monospaced))\n                .foregroundColor(.primary)\n            Spacer()\n            Button {\n                NSPasteboard.general.clearContents()\n                NSPasteboard.general.setString(\"http://localhost:52415/v1\", forType: .string)\n                baseURLCopied = true\n                DispatchQueue.main.asyncAfter(deadline: .now() + 2) {\n                    baseURLCopied = false\n                }\n            } label: {\n                Image(systemName: baseURLCopied ? \"checkmark\" : \"doc.on.doc\")\n                    .imageScale(.small)\n                    .foregroundColor(baseURLCopied ? 
.green : .secondary)\n                    .contentTransition(.symbolEffect(.replace))\n            }\n            .buttonStyle(.plain)\n            .help(\"Copy API base URL\")\n        }\n        .padding(.vertical, 4)\n        .padding(.horizontal, 8)\n    }\n\n    private func collapseButton(isExpanded: Binding<Bool>) -> some View {\n        Button {\n            isExpanded.wrappedValue.toggle()\n        } label: {\n            Label(\n                isExpanded.wrappedValue ? \"Hide\" : \"Show All\",\n                systemImage: isExpanded.wrappedValue ? \"chevron.up\" : \"chevron.down\"\n            )\n            .labelStyle(.titleAndIcon)\n            .contentTransition(.symbolEffect(.replace))\n        }\n        .buttonStyle(.plain)\n        .font(.caption2)\n    }\n    private func instancesToDisplay(_ instances: [InstanceViewModel]) -> [InstanceViewModel] {\n        if showAllInstances {\n            return instances\n        }\n        return []\n    }\n\n    private var shouldShowClusterDetails: Bool {\n        controller.status != .stopped\n    }\n\n    private var shouldShowInstances: Bool {\n        controller.status != .stopped\n    }\n\n    private var statusDetailText: String? {\n        switch controller.status {\n        case .failed(let message):\n            return message\n        case .stopped:\n            if let countdown = controller.launchCountdownSeconds {\n                return \"Launching in \\(countdown)s\"\n            }\n            return nil\n        default:\n            if let countdown = controller.launchCountdownSeconds {\n                return \"Launching in \\(countdown)s\"\n            }\n            if let lastError = controller.lastError {\n                return lastError\n            }\n            if let message = stateService.lastActionMessage {\n                return message\n            }\n            return nil\n        }\n    }\n\n    private var thunderboltStatusText: String {\n        switch networkStatusService.status.thunderboltBridgeState {\n        case .some(.disabled):\n            return \"Thunderbolt Bridge: Disabled\"\n        case .some(.deleted):\n            return \"Thunderbolt Bridge: Deleted\"\n        case .some(.enabled):\n            return \"Thunderbolt Bridge: Enabled\"\n        case nil:\n            return \"Thunderbolt Bridge: Unknown\"\n        }\n    }\n\n    private var thunderboltStatusColor: Color {\n        switch networkStatusService.status.thunderboltBridgeState {\n        case .some(.disabled), .some(.deleted):\n            return .green\n        case .some(.enabled):\n            return .red\n        case nil:\n            return .secondary\n        }\n    }\n\n    /// Shows TB bridge status for all nodes from exo cluster state\n    private var clusterThunderboltBridgeView: some View {\n        let bridgeStatuses = stateService.latestSnapshot?.nodeThunderboltBridge ?? [:]\n        let localNodeId = stateService.localNodeId\n        let nodeProfiles = stateService.latestSnapshot?.nodeProfiles ?? 
[:]\n\n        return VStack(alignment: .leading, spacing: 1) {\n            if bridgeStatuses.isEmpty {\n                Text(\"Cluster TB Bridge: No data\")\n                    .font(.caption2)\n                    .foregroundColor(.secondary)\n            } else {\n                Text(\"Cluster TB Bridge Status:\")\n                    .font(.caption2)\n                    .foregroundColor(.secondary)\n                ForEach(Array(bridgeStatuses.keys.sorted()), id: \\.self) { nodeId in\n                    if let status = bridgeStatuses[nodeId] {\n                        let nodeName =\n                            nodeProfiles[nodeId]?.friendlyName ?? String(nodeId.prefix(8))\n                        let isLocal = nodeId == localNodeId\n                        let prefix = isLocal ? \"  \\(nodeName) (local):\" : \"  \\(nodeName):\"\n                        let statusText =\n                            !status.exists\n                            ? \"N/A\"\n                            : (status.enabled ? \"Enabled\" : \"Disabled\")\n                        let color: Color =\n                            !status.exists\n                            ? .secondary\n                            : (status.enabled ? .red : .green)\n                        Text(\"\\(prefix) \\(statusText)\")\n                            .font(.caption2)\n                            .foregroundColor(color)\n                    }\n                }\n            }\n        }\n    }\n\n    private var interfaceIpList: some View {\n        let statuses = networkStatusService.status.interfaceStatuses\n        return VStack(alignment: .leading, spacing: 1) {\n            Text(\"Interfaces (en0–en7):\")\n                .font(.caption2)\n                .foregroundColor(.secondary)\n            if statuses.isEmpty {\n                Text(\"  Unknown\")\n                    .font(.caption2)\n                    .foregroundColor(.secondary)\n            } else {\n                ForEach(statuses, id: \\.interfaceName) { status in\n                    let ipText = status.ipAddress ?? \"No IP\"\n                    Text(\"  \\(status.interfaceName): \\(ipText)\")\n                        .font(.caption2)\n                        .foregroundColor(status.ipAddress == nil ? .red : .green)\n                }\n            }\n        }\n    }\n\n    private var debugSection: some View {\n        VStack(alignment: .leading, spacing: 4) {\n            HoverButton(\n                title: \"Debug Info\",\n                tint: .primary,\n                trailingSystemImage: showDebugInfo ? 
\"chevron.up\" : \"chevron.down\",\n                small: true\n            ) {\n                showDebugInfo.toggle()\n            }\n            if showDebugInfo {\n                VStack(alignment: .leading, spacing: 4) {\n                    Text(\"Version: \\(buildTag)\")\n                        .font(.caption2)\n                        .foregroundColor(.secondary)\n                    Text(\"Commit: \\(buildCommit)\")\n                        .font(.caption2)\n                        .foregroundColor(.secondary)\n                    Text(thunderboltStatusText)\n                        .font(.caption2)\n                        .foregroundColor(thunderboltStatusColor)\n                    clusterThunderboltBridgeView\n                    interfaceIpList\n                    rdmaStatusView\n                    sendBugReportButton\n                        .padding(.top, 6)\n                }\n                .padding(.leading, 8)\n                .transition(.opacity)\n            }\n        }\n        .animation(.easeInOut(duration: 0.25), value: showDebugInfo)\n    }\n\n    private var rdmaStatusView: some View {\n        let rdmaStatuses = stateService.latestSnapshot?.nodeRdmaCtl ?? [:]\n        let localNodeId = stateService.localNodeId\n        let nodeProfiles = stateService.latestSnapshot?.nodeProfiles ?? [:]\n        let localDevices = networkStatusService.status.localRdmaDevices\n        let localPorts = networkStatusService.status.localRdmaActivePorts\n\n        return VStack(alignment: .leading, spacing: 1) {\n            if rdmaStatuses.isEmpty {\n                Text(\"Cluster RDMA: No data\")\n                    .font(.caption2)\n                    .foregroundColor(.secondary)\n            } else {\n                Text(\"Cluster RDMA Status:\")\n                    .font(.caption2)\n                    .foregroundColor(.secondary)\n                ForEach(Array(rdmaStatuses.keys.sorted()), id: \\.self) { nodeId in\n                    if let status = rdmaStatuses[nodeId] {\n                        let nodeName =\n                            nodeProfiles[nodeId]?.friendlyName ?? String(nodeId.prefix(8))\n                        let isLocal = nodeId == localNodeId\n                        let prefix = isLocal ? \"  \\(nodeName) (local):\" : \"  \\(nodeName):\"\n                        let statusText = status.enabled ? \"Enabled\" : \"Disabled\"\n                        let color: Color = status.enabled ? 
.green : .orange\n                        Text(\"\\(prefix) \\(statusText)\")\n                            .font(.caption2)\n                            .foregroundColor(color)\n                    }\n                }\n            }\n            if !localDevices.isEmpty {\n                Text(\"  Local Devices: \\(localDevices.joined(separator: \", \"))\")\n                    .font(.caption2)\n                    .foregroundColor(.secondary)\n            }\n            if !localPorts.isEmpty {\n                Text(\"  Local Active Ports:\")\n                    .font(.caption2)\n                    .foregroundColor(.secondary)\n                ForEach(localPorts, id: \\.device) { port in\n                    Text(\"    \\(port.device) port \\(port.port): \\(port.state)\")\n                        .font(.caption2)\n                        .foregroundColor(.green)\n                }\n            }\n        }\n    }\n\n    private var sendBugReportButton: some View {\n        VStack(alignment: .leading, spacing: 6) {\n            switch bugReportPhase {\n            case .idle:\n                Button {\n                    bugReportPhase = .prompting\n                    bugReportUserDescription = \"\"\n                } label: {\n                    HStack {\n                        Text(\"Send Bug Report\")\n                            .font(.caption)\n                            .fontWeight(.semibold)\n                        Spacer()\n                    }\n                    .padding(.vertical, 6)\n                    .padding(.horizontal, 8)\n                    .background(\n                        RoundedRectangle(cornerRadius: 6)\n                            .fill(Color.accentColor.opacity(0.12))\n                    )\n                }\n                .buttonStyle(.plain)\n\n            case .prompting:\n                VStack(alignment: .leading, spacing: 6) {\n                    Text(\"What's the issue? 
(optional)\")\n                        .font(.caption2)\n                        .foregroundColor(.secondary)\n                    TextEditor(text: $bugReportUserDescription)\n                        .font(.caption2)\n                        .frame(height: 60)\n                        .overlay(\n                            RoundedRectangle(cornerRadius: 4)\n                                .stroke(Color.secondary.opacity(0.3), lineWidth: 1)\n                        )\n                    HStack(spacing: 8) {\n                        Button(\"Send\") {\n                            Task {\n                                await sendBugReport()\n                            }\n                        }\n                        .font(.caption2)\n                        .buttonStyle(.borderedProminent)\n                        .controlSize(.small)\n                        Button(\"Cancel\") {\n                            bugReportPhase = .idle\n                        }\n                        .font(.caption2)\n                        .buttonStyle(.bordered)\n                        .controlSize(.small)\n                    }\n                }\n                .padding(8)\n                .background(\n                    RoundedRectangle(cornerRadius: 6)\n                        .fill(Color.accentColor.opacity(0.06))\n                )\n\n            case .sending(let message):\n                HStack(spacing: 6) {\n                    ProgressView()\n                        .scaleEffect(0.6)\n                    Text(message)\n                        .font(.caption2)\n                        .foregroundColor(.secondary)\n                }\n\n            case .success(let message):\n                VStack(alignment: .leading, spacing: 6) {\n                    Text(message)\n                        .font(.caption2)\n                        .foregroundColor(.secondary)\n                        .fixedSize(horizontal: false, vertical: true)\n                    Button {\n                        openGitHubIssue()\n                    } label: {\n                        HStack(spacing: 4) {\n                            Image(systemName: \"arrow.up.right.square\")\n                                .imageScale(.small)\n                            Text(\"Create GitHub Issue\")\n                                .font(.caption2)\n                        }\n                    }\n                    .buttonStyle(.bordered)\n                    .controlSize(.small)\n                    Button(\"Done\") {\n                        bugReportPhase = .idle\n                        bugReportUserDescription = \"\"\n                    }\n                    .font(.caption2)\n                    .buttonStyle(.plain)\n                    .foregroundColor(.secondary)\n                }\n\n            case .failure(let message):\n                VStack(alignment: .leading, spacing: 4) {\n                    Text(message)\n                        .font(.caption2)\n                        .foregroundColor(.red)\n                        .fixedSize(horizontal: false, vertical: true)\n                    Button(\"Dismiss\") {\n                        bugReportPhase = .idle\n                    }\n                    .font(.caption2)\n                    .buttonStyle(.plain)\n                    .foregroundColor(.secondary)\n                }\n            }\n        }\n        .animation(.easeInOut(duration: 0.2), value: bugReportPhase)\n    }\n\n    private var processToggleBinding: Binding<Bool> {\n        Binding(\n       
     get: {\n                switch controller.status {\n                case .running, .starting:\n                    return true\n                case .stopped, .failed:\n                    return false\n                }\n            },\n            set: { isOn in\n                if isOn {\n                    stateService.resetTransientState()\n                    stateService.startPolling()\n                    controller.cancelPendingLaunch()\n                    controller.launchIfNeeded()\n                } else {\n                    stateService.stopPolling()\n                    controller.stop()\n                    stateService.resetTransientState()\n                }\n            }\n        )\n    }\n\n    private func bindingForNode(_ node: NodeViewModel) -> Binding<NodeViewModel?> {\n        Binding<NodeViewModel?>(\n            get: {\n                focusedNode?.id == node.id ? focusedNode : nil\n            },\n            set: { newValue in\n                if newValue == nil {\n                    focusedNode = nil\n                } else {\n                    focusedNode = newValue\n                }\n            }\n        )\n    }\n\n    private func sendBugReport() async {\n        bugReportPhase = .sending(\"Collecting logs...\")\n        let service = BugReportService()\n        let description = bugReportUserDescription.trimmingCharacters(in: .whitespacesAndNewlines)\n        do {\n            let outcome = try await service.sendReport(\n                isManual: true,\n                userDescription: description.isEmpty ? nil : description\n            )\n            if outcome.success {\n                bugReportPhase = .success(outcome.message)\n            } else {\n                bugReportPhase = .failure(outcome.message)\n            }\n        } catch {\n            bugReportPhase = .failure(error.localizedDescription)\n        }\n    }\n\n    private func openGitHubIssue() {\n        let description = bugReportUserDescription.trimmingCharacters(in: .whitespacesAndNewlines)\n\n        var bodyParts: [String] = []\n        bodyParts.append(\"## Describe the bug\")\n        bodyParts.append(\"\")\n        if !description.isEmpty {\n            bodyParts.append(description)\n        } else {\n            bodyParts.append(\"A clear and concise description of what the bug is.\")\n        }\n        bodyParts.append(\"\")\n        bodyParts.append(\"## Environment\")\n        bodyParts.append(\"\")\n        bodyParts.append(\"- macOS Version: \\(ProcessInfo.processInfo.operatingSystemVersionString)\")\n        bodyParts.append(\"- EXO Version: \\(buildTag) (\\(buildCommit))\")\n        bodyParts.append(\"\")\n        bodyParts.append(\"## Additional context\")\n        bodyParts.append(\"\")\n        bodyParts.append(\"A bug report with diagnostic logs was submitted via the app.\")\n\n        let body = bodyParts.joined(separator: \"\\n\")\n\n        var components = URLComponents(string: \"https://github.com/exo-explore/exo/issues/new\")!\n        components.queryItems = [\n            URLQueryItem(name: \"template\", value: \"bug_report.md\"),\n            URLQueryItem(name: \"title\", value: \"[BUG] \"),\n            URLQueryItem(name: \"body\", value: body),\n            URLQueryItem(name: \"labels\", value: \"bug\"),\n        ]\n\n        if let url = components.url {\n            NSWorkspace.shared.open(url)\n        }\n    }\n\n    private func showUninstallConfirmationAlert() {\n        let alert = NSAlert()\n        alert.messageText = 
\"Uninstall EXO\"\n        alert.informativeText = \"\"\"\n            This will remove EXO and all its system components:\n\n            • Network configuration daemon\n            • Launch at login registration\n            • EXO network location\n\n            The app will be moved to Trash.\n            \"\"\"\n        alert.alertStyle = .warning\n        alert.addButton(withTitle: \"Uninstall\")\n        alert.addButton(withTitle: \"Cancel\")\n\n        // Style the Uninstall button as destructive\n        if let uninstallButton = alert.buttons.first {\n            uninstallButton.hasDestructiveAction = true\n        }\n\n        let response = alert.runModal()\n        if response == .alertFirstButtonReturn {\n            performUninstall()\n        }\n    }\n\n    private func performUninstall() {\n        uninstallInProgress = true\n\n        // Stop EXO process first\n        controller.cancelPendingLaunch()\n        controller.stop()\n        stateService.stopPolling()\n\n        // Run the privileged uninstall on a background thread\n        // Using .utility QoS to avoid priority inversion with NSAppleScript's subprocess\n        DispatchQueue.global(qos: .utility).async {\n            do {\n                // Remove network setup daemon and components (requires admin privileges)\n                try NetworkSetupHelper.uninstall()\n\n                DispatchQueue.main.async {\n                    // Unregister from launch at login\n                    LaunchAtLoginHelper.disable()\n\n                    // Move app to trash\n                    self.moveAppToTrash()\n\n                    // Quit the app\n                    DispatchQueue.main.asyncAfter(deadline: .now() + 0.5) {\n                        NSApplication.shared.terminate(nil)\n                    }\n                }\n            } catch {\n                DispatchQueue.main.async {\n                    self.showErrorAlert(message: error.localizedDescription)\n                    self.uninstallInProgress = false\n                }\n            }\n        }\n    }\n\n    private func showErrorAlert(message: String) {\n        let alert = NSAlert()\n        alert.messageText = \"Uninstall Failed\"\n        alert.informativeText = message\n        alert.alertStyle = .critical\n        alert.addButton(withTitle: \"OK\")\n        alert.runModal()\n    }\n\n    private func moveAppToTrash() {\n        guard let appURL = Bundle.main.bundleURL as URL? else { return }\n        do {\n            try FileManager.default.trashItem(at: appURL, resultingItemURL: nil)\n        } catch {\n            // If we can't trash the app, that's OK - user can do it manually\n            // The important system components have already been cleaned up\n        }\n    }\n\n    private var buildTag: String {\n        Bundle.main.infoDictionary?[\"EXOBuildTag\"] as? String ?? \"unknown\"\n    }\n\n    private var buildCommit: String {\n        Bundle.main.infoDictionary?[\"EXOBuildCommit\"] as? String ?? \"unknown\"\n    }\n}\n\nprivate struct HoverButton: View {\n    let title: String\n    let tint: Color\n    let trailingSystemImage: String?\n    let small: Bool\n    let action: () -> Void\n\n    init(\n        title: String, tint: Color = .primary, trailingSystemImage: String? 
= nil,\n        small: Bool = false, action: @escaping () -> Void\n    ) {\n        self.title = title\n        self.tint = tint\n        self.trailingSystemImage = trailingSystemImage\n        self.small = small\n        self.action = action\n    }\n\n    @State private var isHovering = false\n\n    var body: some View {\n        Button(action: action) {\n            HStack {\n                Text(title)\n                    .font(small ? .caption : nil)\n                Spacer()\n                if let systemName = trailingSystemImage {\n                    Image(systemName: systemName)\n                        .imageScale(.small)\n                }\n            }\n            .frame(maxWidth: .infinity, alignment: .leading)\n            .padding(.vertical, small ? 4 : 6)\n            .padding(.horizontal, small ? 6 : 8)\n            .background(\n                RoundedRectangle(cornerRadius: 6)\n                    .fill(\n                        isHovering\n                            ? Color.accentColor.opacity(0.1)\n                            : Color.clear\n                    )\n            )\n        }\n        .buttonStyle(.plain)\n        .foregroundColor(tint)\n        .onHover { isHovering = $0 }\n    }\n}\n"
  },
  {
    "path": "app/EXO/EXO/EXO.entitlements",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<!DOCTYPE plist PUBLIC \"-//Apple//DTD PLIST 1.0//EN\" \"http://www.apple.com/DTDs/PropertyList-1.0.dtd\">\n<plist version=\"1.0\">\n<dict>\n\t<key>com.apple.security.app-sandbox</key>\n\t<false/>\n\t<key>com.apple.security.automation.apple-events</key>\n\t<true/>\n\t<key>com.apple.security.files.user-selected.read-only</key>\n\t<true/>\n</dict>\n</plist>\n"
  },
  {
    "path": "app/EXO/EXO/EXOApp.swift",
    "content": "//\n//  EXOApp.swift\n//  EXO\n//\n//  Created by Sami Khan on 2025-11-22.\n//\n\nimport AppKit\nimport CoreImage\nimport CoreImage.CIFilterBuiltins\nimport ServiceManagement\nimport Sparkle\nimport SwiftUI\nimport UserNotifications\nimport os.log\n\nstruct EXOApp: App {\n    @StateObject private var controller: ExoProcessController\n    @StateObject private var stateService: ClusterStateService\n    @StateObject private var networkStatusService: NetworkStatusService\n    @StateObject private var localNetworkChecker: LocalNetworkChecker\n    @StateObject private var updater: SparkleUpdater\n    @StateObject private var thunderboltBridgeService: ThunderboltBridgeService\n    @StateObject private var settingsWindowController: SettingsWindowController\n    private let terminationObserver: TerminationObserver\n    private let firstLaunchPopout = FirstLaunchPopout()\n    private let ciContext = CIContext(options: nil)\n\n    init() {\n        let controller = ExoProcessController()\n        let updater = SparkleUpdater(processController: controller)\n        terminationObserver = TerminationObserver {\n            Task { @MainActor in\n                controller.cancelPendingLaunch()\n                controller.stop()\n            }\n        }\n        _controller = StateObject(wrappedValue: controller)\n        let service = ClusterStateService()\n        _stateService = StateObject(wrappedValue: service)\n        let networkStatus = NetworkStatusService()\n        _networkStatusService = StateObject(wrappedValue: networkStatus)\n        let localNetwork = LocalNetworkChecker()\n        _localNetworkChecker = StateObject(wrappedValue: localNetwork)\n        _updater = StateObject(wrappedValue: updater)\n        let thunderboltBridge = ThunderboltBridgeService(clusterStateService: service)\n        _thunderboltBridgeService = StateObject(wrappedValue: thunderboltBridge)\n        _settingsWindowController = StateObject(wrappedValue: SettingsWindowController())\n        enableLaunchAtLoginIfNeeded()\n        // Install LaunchDaemon to disable Thunderbolt Bridge on startup (prevents network loops)\n        NetworkSetupHelper.promptAndInstallIfNeeded()\n        // Check local network access periodically (warning disappears when user grants permission)\n        localNetwork.startPeriodicChecking(interval: 10)\n        controller.scheduleLaunch(after: 5)\n        service.startPolling()\n        networkStatus.startPolling()\n    }\n\n    var body: some Scene {\n        MenuBarExtra {\n            ContentView()\n                .environmentObject(controller)\n                .environmentObject(stateService)\n                .environmentObject(networkStatusService)\n                .environmentObject(localNetworkChecker)\n                .environmentObject(updater)\n                .environmentObject(thunderboltBridgeService)\n                .environmentObject(settingsWindowController)\n        } label: {\n            menuBarIcon\n                .onReceive(controller.$isFirstLaunchReady) { ready in\n                    if ready {\n                        DispatchQueue.main.asyncAfter(deadline: .now() + 3.0) {\n                            self.firstLaunchPopout.onComplete = { [weak controller] in\n                                controller?.markOnboardingCompleted()\n                            }\n                            self.firstLaunchPopout.show()\n                        }\n                    }\n                }\n        }\n        .menuBarExtraStyle(.window)\n    }\n\n    
private var menuBarIcon: some View {\n        let baseImage = resizedMenuBarIcon(named: \"menubar-icon\", size: 26)\n        let iconImage: NSImage\n        if controller.status == .stopped, let grey = greyscale(image: baseImage) {\n            iconImage = grey\n        } else {\n            iconImage = baseImage ?? NSImage(named: \"menubar-icon\") ?? NSImage()\n        }\n        return Image(nsImage: iconImage)\n            .accessibilityLabel(\"EXO\")\n    }\n\n    private func resizedMenuBarIcon(named: String, size: CGFloat) -> NSImage? {\n        guard let original = NSImage(named: named) else {\n            print(\"Failed to load image named: \\(named)\")\n            return nil\n        }\n        let targetSize = NSSize(width: size, height: size)\n        let resized = NSImage(size: targetSize)\n        resized.lockFocus()\n        defer { resized.unlockFocus() }\n        NSGraphicsContext.current?.imageInterpolation = .high\n        original.draw(\n            in: NSRect(origin: .zero, size: targetSize),\n            from: NSRect(origin: .zero, size: original.size),\n            operation: .copy,\n            fraction: 1.0\n        )\n        return resized\n    }\n\n    private func greyscale(image: NSImage?) -> NSImage? {\n        guard\n            let image,\n            let tiff = image.tiffRepresentation,\n            let bitmap = NSBitmapImageRep(data: tiff),\n            let cgImage = bitmap.cgImage\n        else {\n            return nil\n        }\n\n        let ciImage = CIImage(cgImage: cgImage)\n        let filter = CIFilter.colorControls()\n        filter.inputImage = ciImage\n        filter.saturation = 0\n        filter.brightness = -0.2\n        filter.contrast = 0.9\n\n        guard let output = filter.outputImage,\n            let rendered = ciContext.createCGImage(output, from: output.extent)\n        else {\n            return nil\n        }\n\n        return NSImage(cgImage: rendered, size: image.size)\n    }\n\n    private func enableLaunchAtLoginIfNeeded() {\n        guard SMAppService.mainApp.status != .enabled else { return }\n        do {\n            try SMAppService.mainApp.register()\n        } catch {\n            Logger().error(\n                \"Failed to register EXO for launch at login: \\(error.localizedDescription)\")\n        }\n    }\n\n}\n\n/// Helper for managing EXO's launch-at-login registration\nenum LaunchAtLoginHelper {\n    private static let logger = Logger(subsystem: \"io.exo.EXO\", category: \"LaunchAtLogin\")\n\n    /// Unregisters EXO from launching at login\n    static func disable() {\n        guard SMAppService.mainApp.status == .enabled else { return }\n        do {\n            try SMAppService.mainApp.unregister()\n            logger.info(\"Unregistered EXO from launch at login\")\n        } catch {\n            logger.error(\n                \"Failed to unregister EXO from launch at login: \\(error.localizedDescription, privacy: .public)\"\n            )\n        }\n    }\n}\n\nfinal class SparkleUpdater: NSObject, ObservableObject {\n    private let controller: SPUStandardUpdaterController\n    private let delegateProxy: ExoUpdaterDelegate\n    private let notificationDelegate = ExoNotificationDelegate()\n    private var periodicCheckTask: Task<Void, Never>?\n\n    init(processController: ExoProcessController) {\n        let proxy = ExoUpdaterDelegate(processController: processController)\n        delegateProxy = proxy\n        controller = SPUStandardUpdaterController(\n            startingUpdater: true,\n            
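// The delegate stops the EXO process before Sparkle relaunches to install an update.\n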
updaterDelegate: proxy,\n            userDriverDelegate: nil\n        )\n        super.init()\n        let center = UNUserNotificationCenter.current()\n        center.delegate = notificationDelegate\n        center.requestAuthorization(options: [.alert, .sound]) { _, _ in }\n        controller.updater.automaticallyChecksForUpdates = true\n        controller.updater.automaticallyDownloadsUpdates = false\n        controller.updater.updateCheckInterval = 900  // 15 minutes\n        DispatchQueue.main.asyncAfter(deadline: .now() + 5) { [weak controller] in\n            controller?.updater.checkForUpdatesInBackground()\n        }\n        let updater = controller.updater\n        let intervalSeconds = max(60.0, controller.updater.updateCheckInterval)\n        let intervalNanos = UInt64(intervalSeconds * 1_000_000_000)\n        periodicCheckTask = Task {\n            while !Task.isCancelled {\n                try? await Task.sleep(nanoseconds: intervalNanos)\n                await MainActor.run {\n                    updater.checkForUpdatesInBackground()\n                }\n            }\n        }\n    }\n\n    deinit {\n        periodicCheckTask?.cancel()\n    }\n\n    @MainActor\n    func checkForUpdates() {\n        controller.checkForUpdates(nil)\n    }\n}\n\nprivate final class ExoUpdaterDelegate: NSObject, SPUUpdaterDelegate {\n    private weak var processController: ExoProcessController?\n\n    init(processController: ExoProcessController) {\n        self.processController = processController\n    }\n\n    nonisolated func updater(_ updater: SPUUpdater, didFindValidUpdate item: SUAppcastItem) {\n        showNotification(\n            title: \"Update available\",\n            body: \"EXO \\(item.displayVersionString) is ready to install.\"\n        )\n    }\n\n    nonisolated func updaterWillRelaunchApplication(_ updater: SPUUpdater) {\n        Task { @MainActor in\n            guard let controller = self.processController else { return }\n            controller.cancelPendingLaunch()\n            controller.stop()\n        }\n    }\n\n    nonisolated private func showNotification(title: String, body: String) {\n        let center = UNUserNotificationCenter.current()\n        let content = UNMutableNotificationContent()\n        content.title = title\n        content.body = body\n        let request = UNNotificationRequest(\n            identifier: \"exo-update-\\(UUID().uuidString)\",\n            content: content,\n            trigger: nil\n        )\n        center.add(request, withCompletionHandler: nil)\n    }\n}\n\nprivate final class ExoNotificationDelegate: NSObject, UNUserNotificationCenterDelegate {\n    func userNotificationCenter(\n        _ center: UNUserNotificationCenter,\n        willPresent notification: UNNotification,\n        withCompletionHandler completionHandler: @escaping (UNNotificationPresentationOptions) ->\n            Void\n    ) {\n        completionHandler([.banner, .list, .sound])\n    }\n}\n\n@MainActor\nprivate final class TerminationObserver {\n    private var token: NSObjectProtocol?\n\n    init(onTerminate: @escaping () -> Void) {\n        token = NotificationCenter.default.addObserver(\n            forName: NSApplication.willTerminateNotification,\n            object: nil,\n            queue: .main\n        ) { _ in\n            onTerminate()\n        }\n    }\n\n    deinit {\n        if let token {\n            NotificationCenter.default.removeObserver(token)\n        }\n    }\n}\n"
  },
  {
    "path": "app/EXO/EXO/ExoProcessController.swift",
    "content": "import AppKit\nimport Combine\nimport Foundation\n\nprivate let customNamespaceKey = \"EXOCustomNamespace\"\nprivate let hfTokenKey = \"EXOHFToken\"\nprivate let enableImageModelsKey = \"EXOEnableImageModels\"\nprivate let offlineModeKey = \"EXOOfflineMode\"\nprivate let onboardingCompletedKey = \"EXOOnboardingCompleted\"\n\n@MainActor\nfinal class ExoProcessController: ObservableObject {\n    enum Status: Equatable {\n        case stopped\n        case starting\n        case running\n        case failed(message: String)\n\n        var displayText: String {\n            switch self {\n            case .stopped:\n                return \"Stopped\"\n            case .starting:\n                return \"Starting…\"\n            case .running:\n                return \"Running\"\n            case .failed:\n                return \"Failed\"\n            }\n        }\n    }\n\n    static let exoDirectoryURL: URL = {\n        URL(fileURLWithPath: NSHomeDirectory()).appendingPathComponent(\".exo\")\n    }()\n\n    @Published private(set) var status: Status = .stopped\n    @Published private(set) var lastError: String?\n    @Published private(set) var launchCountdownSeconds: Int?\n    @Published var customNamespace: String = {\n        return UserDefaults.standard.string(forKey: customNamespaceKey) ?? \"\"\n    }()\n    {\n        didSet {\n            UserDefaults.standard.set(customNamespace, forKey: customNamespaceKey)\n        }\n    }\n    @Published var hfToken: String = {\n        return UserDefaults.standard.string(forKey: hfTokenKey) ?? \"\"\n    }()\n    {\n        didSet {\n            UserDefaults.standard.set(hfToken, forKey: hfTokenKey)\n        }\n    }\n    @Published var enableImageModels: Bool = {\n        return UserDefaults.standard.bool(forKey: enableImageModelsKey)\n    }()\n    {\n        didSet {\n            UserDefaults.standard.set(enableImageModels, forKey: enableImageModelsKey)\n        }\n    }\n    @Published var offlineMode: Bool = {\n        return UserDefaults.standard.bool(forKey: offlineModeKey)\n    }()\n    {\n        didSet {\n            UserDefaults.standard.set(offlineMode, forKey: offlineModeKey)\n        }\n    }\n\n    /// Fires once when EXO transitions to `.running` for the very first time (fresh install).\n    @Published private(set) var isFirstLaunchReady = false\n\n    private var process: Process?\n    private var runtimeDirectoryURL: URL?\n    private var pendingLaunchTask: Task<Void, Never>?\n\n    func launchIfNeeded() {\n        guard process?.isRunning != true else { return }\n        launch()\n    }\n\n    func launch() {\n        do {\n            guard process?.isRunning != true else { return }\n            cancelPendingLaunch()\n            status = .starting\n            lastError = nil\n            let runtimeURL = try resolveRuntimeDirectory()\n            runtimeDirectoryURL = runtimeURL\n\n            let executableURL = runtimeURL.appendingPathComponent(\"exo\")\n\n            let child = Process()\n            child.executableURL = executableURL\n            let exoHomeURL = Self.exoDirectoryURL\n            try? 
FileManager.default.createDirectory(\n                at: exoHomeURL, withIntermediateDirectories: true\n            )\n            child.currentDirectoryURL = exoHomeURL\n            child.environment = makeEnvironment(for: runtimeURL)\n\n            child.standardOutput = FileHandle.nullDevice\n            child.standardError = FileHandle.nullDevice\n\n            child.terminationHandler = { [weak self] proc in\n                Task { @MainActor in\n                    guard let self else { return }\n                    self.process = nil\n                    switch self.status {\n                    case .stopped:\n                        break\n                    case .failed:\n                        break\n                    default:\n                        self.status = .failed(\n                            message: \"Exited with code \\(proc.terminationStatus)\"\n                        )\n                        self.lastError = \"Process exited with code \\(proc.terminationStatus)\"\n                    }\n                }\n            }\n\n            try child.run()\n            process = child\n            status = .running\n\n            // Show welcome popout on every launch\n            isFirstLaunchReady = true\n        } catch {\n            process = nil\n            status = .failed(message: \"Launch error\")\n            lastError = error.localizedDescription\n        }\n    }\n\n    func stop() {\n        guard let process else {\n            status = .stopped\n            return\n        }\n        process.terminationHandler = nil\n        status = .stopped\n\n        guard process.isRunning else {\n            self.process = nil\n            return\n        }\n\n        let proc = process\n        self.process = nil\n\n        Task.detached {\n            proc.interrupt()\n\n            for _ in 0..<50 {\n                if !proc.isRunning { return }\n                try? await Task.sleep(nanoseconds: 100_000_000)\n            }\n\n            if proc.isRunning {\n                proc.terminate()\n            }\n\n            for _ in 0..<30 {\n                if !proc.isRunning { return }\n                try? await Task.sleep(nanoseconds: 100_000_000)\n            }\n\n            if proc.isRunning {\n                kill(proc.processIdentifier, SIGKILL)\n            }\n        }\n    }\n\n    func restart() {\n        stop()\n        launch()\n    }\n\n    /// Mark onboarding as completed (user interacted with the welcome popout).\n    func markOnboardingCompleted() {\n        UserDefaults.standard.set(true, forKey: onboardingCompletedKey)\n    }\n\n    /// Reset onboarding so the welcome popout appears on next launch.\n    func resetOnboarding() {\n        UserDefaults.standard.removeObject(forKey: onboardingCompletedKey)\n        isFirstLaunchReady = false\n    }\n\n    func scheduleLaunch(after seconds: TimeInterval) {\n        cancelPendingLaunch()\n        let start = max(1, Int(ceil(seconds)))\n        pendingLaunchTask = Task { [weak self] in\n            guard let self else { return }\n            await MainActor.run {\n                self.launchCountdownSeconds = start\n            }\n            var remaining = start\n            while remaining > 0 {\n                try? 
await Task.sleep(nanoseconds: 1_000_000_000)\n                remaining -= 1\n                if Task.isCancelled { return }\n                await MainActor.run {\n                    if remaining > 0 {\n                        self.launchCountdownSeconds = remaining\n                    } else {\n                        self.launchCountdownSeconds = nil\n                        self.launchIfNeeded()\n                    }\n                }\n            }\n        }\n    }\n\n    func cancelPendingLaunch() {\n        pendingLaunchTask?.cancel()\n        pendingLaunchTask = nil\n        launchCountdownSeconds = nil\n    }\n\n    func revealRuntimeDirectory() {\n        guard let runtimeDirectoryURL else { return }\n        NSWorkspace.shared.activateFileViewerSelecting([runtimeDirectoryURL])\n    }\n\n    func statusTintColor() -> NSColor {\n        switch status {\n        case .running:\n            return .systemGreen\n        case .starting:\n            return .systemYellow\n        case .failed:\n            return .systemRed\n        case .stopped:\n            return .systemGray\n        }\n    }\n\n    private func resolveRuntimeDirectory() throws -> URL {\n        let fileManager = FileManager.default\n\n        if let override = ProcessInfo.processInfo.environment[\"EXO_RUNTIME_DIR\"] {\n            let url = URL(fileURLWithPath: override).standardizedFileURL\n            if fileManager.fileExists(atPath: url.path) {\n                return url\n            }\n        }\n\n        if let resourceRoot = Bundle.main.resourceURL {\n            let bundled = resourceRoot.appendingPathComponent(\"exo\", isDirectory: true)\n            if fileManager.fileExists(atPath: bundled.path) {\n                return bundled\n            }\n        }\n\n        let repoCandidate = URL(fileURLWithPath: fileManager.currentDirectoryPath)\n            .appendingPathComponent(\"dist/exo\", isDirectory: true)\n        if fileManager.fileExists(atPath: repoCandidate.path) {\n            return repoCandidate\n        }\n\n        throw RuntimeError(\"Unable to locate the packaged EXO runtime.\")\n    }\n\n    private func makeEnvironment(for runtimeURL: URL) -> [String: String] {\n        var environment = ProcessInfo.processInfo.environment\n        environment[\"EXO_RUNTIME_DIR\"] = runtimeURL.path\n        environment[\"EXO_LIBP2P_NAMESPACE\"] = computeNamespace()\n        if !hfToken.isEmpty {\n            environment[\"HF_TOKEN\"] = hfToken\n        }\n        if enableImageModels {\n            environment[\"EXO_ENABLE_IMAGE_MODELS\"] = \"true\"\n        }\n        if offlineMode {\n            environment[\"EXO_OFFLINE\"] = \"true\"\n        }\n\n        var paths: [String] = []\n        if let existing = environment[\"PATH\"], !existing.isEmpty {\n            paths = existing.split(separator: \":\").map(String.init)\n        }\n\n        let required = [\n            runtimeURL.path,\n            runtimeURL.appendingPathComponent(\"_internal\").path,\n            \"/opt/homebrew/bin\",\n            \"/usr/local/bin\",\n            \"/usr/bin\",\n            \"/bin\",\n            \"/usr/sbin\",\n            \"/sbin\",\n        ]\n\n        for entry in required.reversed() {\n            if !paths.contains(entry) {\n                paths.insert(entry, at: 0)\n            }\n        }\n\n        environment[\"PATH\"] = paths.joined(separator: \":\")\n        return environment\n    }\n\n    private func buildTag() -> String {\n        if let tag = Bundle.main.infoDictionary?[\"EXOBuildTag\"] 
as? String, !tag.isEmpty {\n            return tag\n        }\n        if let short = Bundle.main.infoDictionary?[\"CFBundleShortVersionString\"] as? String,\n            !short.isEmpty\n        {\n            return short\n        }\n        return \"dev\"\n    }\n\n    private func computeNamespace() -> String {\n        let base = buildTag()\n        let custom = customNamespace.trimmingCharacters(in: .whitespaces)\n        return custom.isEmpty ? base : custom\n    }\n}\n\nstruct RuntimeError: LocalizedError {\n    let message: String\n\n    init(_ message: String) {\n        self.message = message\n    }\n\n    var errorDescription: String? {\n        message\n    }\n}\n"
  },
  {
    "path": "app/EXO/EXO/Info.plist",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<!DOCTYPE plist PUBLIC \"-//Apple//DTD PLIST 1.0//EN\" \"http://www.apple.com/DTDs/PropertyList-1.0.dtd\">\n<plist version=\"1.0\">\n<dict>\n\t<key>SUFeedURL</key>\n\t<string>https://assets.exolabs.net/appcast.xml</string>\n\t<key>EXOBuildTag</key>\n\t<string>$(EXO_BUILD_TAG)</string>\n\t<key>EXOBuildCommit</key>\n\t<string>$(EXO_BUILD_COMMIT)</string>\n\t<key>EXOBugReportPresignedUrlEndpoint</key>\n\t<string>$(EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT)</string>\n\t<key>NSLocalNetworkUsageDescription</key>\n\t<string>EXO needs local network access to discover and connect to other devices in your cluster for distributed AI inference.</string>\n\t<key>NSBonjourServices</key>\n\t<array>\n\t\t<string>_p2p._tcp</string>\n\t\t<string>_p2p._udp</string>\n\t\t<string>_libp2p._udp</string>\n\t</array>\n</dict>\n</plist>\n"
  },
  {
    "path": "app/EXO/EXO/Models/ClusterState.swift",
    "content": "import Foundation\n\n// MARK: - API payloads\n\nstruct ClusterState: Decodable {\n    let instances: [String: ClusterInstance]\n    let runners: [String: RunnerStatusSummary]\n    let tasks: [String: ClusterTask]\n    let topology: Topology?\n    let downloads: [String: [NodeDownloadStatus]]\n    let thunderboltBridgeCycles: [[String]]\n\n    // Granular node state (split from the old nodeProfiles)\n    let nodeIdentities: [String: NodeIdentity]\n    let nodeMemory: [String: MemoryInfo]\n    let nodeSystem: [String: SystemInfo]\n    let nodeThunderboltBridge: [String: ThunderboltBridgeStatus]\n    let nodeRdmaCtl: [String: NodeRdmaCtlStatus]\n\n    /// Computed property for backwards compatibility - merges granular state into NodeProfile\n    var nodeProfiles: [String: NodeProfile] {\n        var profiles: [String: NodeProfile] = [:]\n        let allNodeIds = Set(nodeIdentities.keys)\n            .union(nodeMemory.keys)\n            .union(nodeSystem.keys)\n        for nodeId in allNodeIds {\n            let identity = nodeIdentities[nodeId]\n            let memory = nodeMemory[nodeId]\n            let system = nodeSystem[nodeId]\n            profiles[nodeId] = NodeProfile(\n                modelId: identity?.modelId,\n                chipId: identity?.chipId,\n                friendlyName: identity?.friendlyName,\n                memory: memory,\n                system: system\n            )\n        }\n        return profiles\n    }\n\n    init(from decoder: Decoder) throws {\n        let container = try decoder.container(keyedBy: CodingKeys.self)\n        let rawInstances = try container.decode([String: TaggedInstance].self, forKey: .instances)\n        self.instances = rawInstances.mapValues(\\.instance)\n        self.runners = try container.decode([String: RunnerStatusSummary].self, forKey: .runners)\n        let rawTasks =\n            try container.decodeIfPresent([String: TaggedTask].self, forKey: .tasks) ?? [:]\n        self.tasks = rawTasks.compactMapValues(\\.task)\n        self.topology = try container.decodeIfPresent(Topology.self, forKey: .topology)\n        let rawDownloads =\n            try container.decodeIfPresent([String: [TaggedNodeDownload]].self, forKey: .downloads)\n            ?? [:]\n        self.downloads = rawDownloads.mapValues { $0.compactMap(\\.status) }\n        self.thunderboltBridgeCycles =\n            try container.decodeIfPresent([[String]].self, forKey: .thunderboltBridgeCycles) ?? []\n\n        // Granular node state\n        self.nodeIdentities =\n            try container.decodeIfPresent([String: NodeIdentity].self, forKey: .nodeIdentities)\n            ?? [:]\n        self.nodeMemory =\n            try container.decodeIfPresent([String: MemoryInfo].self, forKey: .nodeMemory) ?? [:]\n        self.nodeSystem =\n            try container.decodeIfPresent([String: SystemInfo].self, forKey: .nodeSystem) ?? [:]\n        self.nodeThunderboltBridge =\n            try container.decodeIfPresent(\n                [String: ThunderboltBridgeStatus].self, forKey: .nodeThunderboltBridge\n            ) ?? [:]\n        self.nodeRdmaCtl =\n            try container.decodeIfPresent(\n                [String: NodeRdmaCtlStatus].self, forKey: .nodeRdmaCtl\n            ) ?? 
[:]\n    }\n\n    private enum CodingKeys: String, CodingKey {\n        case instances\n        case runners\n        case topology\n        case tasks\n        case downloads\n        case thunderboltBridgeCycles\n        case nodeIdentities\n        case nodeMemory\n        case nodeSystem\n        case nodeThunderboltBridge\n        case nodeRdmaCtl\n    }\n}\n\nprivate struct TaggedInstance: Decodable {\n    let instance: ClusterInstance\n\n    init(from decoder: Decoder) throws {\n        let container = try decoder.singleValueContainer()\n        let payloads = try container.decode([String: ClusterInstancePayload].self)\n        guard let entry = payloads.first else {\n            throw DecodingError.dataCorrupted(\n                DecodingError.Context(\n                    codingPath: decoder.codingPath, debugDescription: \"Empty instance payload\")\n            )\n        }\n        self.instance = ClusterInstance(\n            instanceId: entry.value.instanceId,\n            shardAssignments: entry.value.shardAssignments,\n            sharding: entry.key\n        )\n    }\n}\n\nprivate struct ClusterInstancePayload: Decodable {\n    let instanceId: String?\n    let shardAssignments: ShardAssignments\n}\n\nstruct ClusterInstance {\n    let instanceId: String?\n    let shardAssignments: ShardAssignments\n    let sharding: String\n}\n\nstruct ShardAssignments: Decodable {\n    let modelId: String\n    let nodeToRunner: [String: String]\n}\n\nstruct RunnerStatusSummary: Decodable {\n    let status: String\n    let errorMessage: String?\n\n    init(from decoder: Decoder) throws {\n        let container = try decoder.singleValueContainer()\n        let payloads = try container.decode([String: RunnerStatusDetail].self)\n        guard let entry = payloads.first else {\n            throw DecodingError.dataCorrupted(\n                DecodingError.Context(\n                    codingPath: decoder.codingPath, debugDescription: \"Empty runner status payload\")\n            )\n        }\n        self.status = entry.key\n        self.errorMessage = entry.value.errorMessage\n    }\n}\n\nstruct RunnerStatusDetail: Decodable {\n    let errorMessage: String?\n}\n\nstruct NodeProfile: Decodable {\n    let modelId: String?\n    let chipId: String?\n    let friendlyName: String?\n    let memory: MemoryInfo?\n    let system: SystemInfo?\n}\n\nstruct NodeIdentity: Decodable {\n    let modelId: String?\n    let chipId: String?\n    let friendlyName: String?\n}\n\nstruct ThunderboltBridgeStatus: Decodable {\n    let enabled: Bool\n    let exists: Bool\n    let serviceName: String?\n}\n\nstruct NodeRdmaCtlStatus: Decodable {\n    let enabled: Bool\n}\n\nstruct MemoryInfo: Decodable {\n    let ramTotal: MemoryValue?\n    let ramAvailable: MemoryValue?\n}\n\nstruct MemoryValue: Decodable {\n    let inBytes: Int64?\n}\n\nstruct SystemInfo: Decodable {\n    let gpuUsage: Double?\n    let temp: Double?\n    let sysPower: Double?\n    let pcpuUsage: Double?\n    let ecpuUsage: Double?\n}\n\nstruct Topology: Decodable {\n    /// Node IDs in the topology\n    let nodes: [String]\n    /// Flattened list of connections (source -> sink pairs)\n    let connections: [TopologyConnection]\n\n    init(from decoder: Decoder) throws {\n        let container = try decoder.container(keyedBy: CodingKeys.self)\n        self.nodes = try container.decodeIfPresent([String].self, forKey: .nodes) ?? 
[]\n\n        // Connections come as nested map: { source: { sink: [edges] } }\n        // We flatten to array of (source, sink) pairs\n        var flatConnections: [TopologyConnection] = []\n        if let nested = try container.decodeIfPresent(\n            [String: [String: [AnyCodable]]].self, forKey: .connections\n        ) {\n            for (source, sinks) in nested {\n                for sink in sinks.keys {\n                    flatConnections.append(\n                        TopologyConnection(localNodeId: source, sendBackNodeId: sink))\n                }\n            }\n        }\n        self.connections = flatConnections\n    }\n\n    private enum CodingKeys: String, CodingKey {\n        case nodes\n        case connections\n    }\n}\n\n/// Placeholder for decoding arbitrary JSON values we don't need to inspect\nprivate struct AnyCodable: Decodable {\n    init(from decoder: Decoder) throws {\n        // Just consume the value without storing it\n        _ = try? decoder.singleValueContainer().decode(Bool.self)\n        _ = try? decoder.singleValueContainer().decode(Int.self)\n        _ = try? decoder.singleValueContainer().decode(Double.self)\n        _ = try? decoder.singleValueContainer().decode(String.self)\n        _ = try? decoder.singleValueContainer().decode([AnyCodable].self)\n        _ = try? decoder.singleValueContainer().decode([String: AnyCodable].self)\n    }\n}\n\nstruct TopologyConnection {\n    let localNodeId: String\n    let sendBackNodeId: String\n}\n\n// MARK: - Downloads\n\nprivate struct TaggedNodeDownload: Decodable {\n    let status: NodeDownloadStatus?\n\n    init(from decoder: Decoder) throws {\n        let container = try decoder.singleValueContainer()\n        let payloads = try container.decode([String: NodeDownloadPayload].self)\n        guard let entry = payloads.first else {\n            status = nil\n            return\n        }\n        status = NodeDownloadStatus(statusKey: entry.key, payload: entry.value)\n    }\n}\n\nstruct NodeDownloadPayload: Decodable {\n    let nodeId: String?\n    let downloadProgress: DownloadProgress?\n}\n\nstruct NodeDownloadStatus {\n    let nodeId: String\n    let progress: DownloadProgress?\n\n    init?(statusKey: String, payload: NodeDownloadPayload) {\n        guard let nodeId = payload.nodeId else { return nil }\n        self.nodeId = nodeId\n        self.progress = statusKey == \"DownloadOngoing\" ? 
payload.downloadProgress : nil\n    }\n}\n\nstruct DownloadProgress: Decodable {\n    let totalBytes: ByteValue\n    let downloadedBytes: ByteValue\n    let speed: Double?\n    let etaMs: Int64?\n    let completedFiles: Int?\n    let totalFiles: Int?\n    let files: [String: FileDownloadProgress]?\n}\n\nstruct ByteValue: Decodable {\n    let inBytes: Int64\n}\n\nstruct FileDownloadProgress: Decodable {\n    let totalBytes: ByteValue\n    let downloadedBytes: ByteValue\n    let speed: Double?\n    let etaMs: Int64?\n}\n\n// MARK: - Tasks\n\nstruct ClusterTask {\n    enum Kind {\n        case chatCompletion\n    }\n\n    let id: String\n    let status: TaskStatus\n    let instanceId: String?\n    let kind: Kind\n    let modelName: String?\n    let promptPreview: String?\n    let errorMessage: String?\n    let parameters: TextGenerationTaskParameters?\n\n    var sortPriority: Int {\n        switch status {\n        case .running:\n            return 0\n        case .pending:\n            return 1\n        case .complete:\n            return 2\n        case .failed:\n            return 3\n        case .unknown:\n            return 4\n        }\n    }\n}\n\nprivate struct TaggedTask: Decodable {\n    let task: ClusterTask?\n\n    init(from decoder: Decoder) throws {\n        let container = try decoder.singleValueContainer()\n        let payloads = try container.decode([String: ClusterTaskPayload].self)\n        guard let entry = payloads.first else {\n            task = nil\n            return\n        }\n        task = ClusterTask(kindKey: entry.key, payload: entry.value)\n    }\n}\n\nstruct ClusterTaskPayload: Decodable {\n    let taskId: String?\n    let taskStatus: TaskStatus?\n    let instanceId: String?\n    let commandId: String?\n    let taskParams: TextGenerationTaskParameters?\n    let errorType: String?\n    let errorMessage: String?\n}\n\nstruct TextGenerationTaskParameters: Decodable, Equatable {\n    let model: String?\n    let messages: [ChatCompletionMessage]?\n    let maxTokens: Int?\n    let stream: Bool?\n    let temperature: Double?\n    let topP: Double?\n\n    private enum CodingKeys: String, CodingKey {\n        case model\n        case messages\n        case maxTokens\n        case stream\n        case temperature\n        case topP\n    }\n\n    func promptPreview() -> String? {\n        guard let messages else { return nil }\n        if let userMessage = messages.last(where: {\n            $0.role?.lowercased() == \"user\" && ($0.content?.isEmpty == false)\n        }) {\n            return userMessage.content\n        }\n        return messages.last?.content\n    }\n\n}\n\nstruct ChatCompletionMessage: Decodable, Equatable {\n    let role: String?\n    let content: String?\n}\n\nextension ClusterTask {\n    init?(kindKey: String, payload: ClusterTaskPayload) {\n        guard let id = payload.taskId else { return nil }\n        let status = payload.taskStatus ?? 
.unknown\n        switch kindKey {\n        case \"TextGeneration\":\n            self.init(\n                id: id,\n                status: status,\n                instanceId: payload.instanceId,\n                kind: .chatCompletion,\n                modelName: payload.taskParams?.model,\n                promptPreview: payload.taskParams?.promptPreview(),\n                errorMessage: payload.errorMessage,\n                parameters: payload.taskParams\n            )\n        default:\n            return nil\n        }\n    }\n}\n\nenum TaskStatus: String, Decodable {\n    case pending = \"Pending\"\n    case running = \"Running\"\n    case complete = \"Complete\"\n    case failed = \"Failed\"\n    case unknown\n\n    init(from decoder: Decoder) throws {\n        let container = try decoder.singleValueContainer()\n        let value = try container.decode(String.self)\n        self = TaskStatus(rawValue: value) ?? .unknown\n    }\n\n    var displayLabel: String {\n        switch self {\n        case .pending, .running, .complete, .failed:\n            return rawValue\n        case .unknown:\n            return \"Unknown\"\n        }\n    }\n}\n\n// MARK: - Derived summaries\n\nstruct ClusterOverview {\n    let totalRam: Double\n    let usedRam: Double\n    let nodeCount: Int\n    let instanceCount: Int\n}\n\nstruct NodeSummary: Identifiable {\n    let id: String\n    let friendlyName: String\n    let model: String\n    let usedRamGB: Double\n    let totalRamGB: Double\n    let gpuUsagePercent: Double\n    let temperatureCelsius: Double\n}\n\nstruct InstanceSummary: Identifiable {\n    let id: String\n    let modelId: String\n    let nodeCount: Int\n    let statusText: String\n}\n\nextension ClusterState {\n    func overview() -> ClusterOverview {\n        var total: Double = 0\n        var available: Double = 0\n        for profile in nodeProfiles.values {\n            if let totalBytes = profile.memory?.ramTotal?.inBytes {\n                total += Double(totalBytes)\n            }\n            if let availableBytes = profile.memory?.ramAvailable?.inBytes {\n                available += Double(availableBytes)\n            }\n        }\n        let totalGB = total / 1_073_741_824.0\n        let usedGB = max(total - available, 0) / 1_073_741_824.0\n        return ClusterOverview(\n            totalRam: totalGB,\n            usedRam: usedGB,\n            nodeCount: nodeProfiles.count,\n            instanceCount: instances.count\n        )\n    }\n\n    func availableModels() -> [ModelOption] { [] }\n}\n"
  },
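The `/state` payload above leans on one recurring pattern: enum-like variants arrive as single-entry maps whose key is the tag (`TaggedInstance`, `RunnerStatusSummary`, `TaggedTask`, `TaggedNodeDownload`), and node state arrives in granular per-topic maps that the `nodeProfiles` computed property re-merges for older call sites. A minimal decoding sketch against a hypothetical payload (field values are illustrative, not captured from a real node), using the same `convertFromSnakeCase` strategy that `ClusterStateService` configures:

```swift
import Foundation

// Hypothetical /state fragment: a runner status is a single-entry map
// {"Running": {...}} whose key carries the variant tag.
let sampleJson = Data("""
{
  "instances": {},
  "runners": { "runner-1": { "Running": { "error_message": null } } },
  "node_identities": { "node-a": { "model_id": "Mac15,6", "friendly_name": "studio" } },
  "node_memory": { "node-a": { "ram_total": { "in_bytes": 38654705664 } } }
}
""".utf8)

let decoder = JSONDecoder()
decoder.keyDecodingStrategy = .convertFromSnakeCase  // matches ClusterStateService

do {
    let state = try decoder.decode(ClusterState.self, from: sampleJson)
    // The variant tag is lifted into the status string...
    print(state.runners["runner-1"]?.status ?? "-")           // "Running"
    // ...and the granular maps are merged back into a NodeProfile view.
    print(state.nodeProfiles["node-a"]?.friendlyName ?? "-")  // "studio"
} catch {
    print("Decoding failed: \(error)")
}
```

Note that `convertFromSnakeCase` also rewrites dictionary keys, so this pattern relies on node and runner identifiers not containing underscores.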
  {
    "path": "app/EXO/EXO/Preview Content/Preview Assets.xcassets/Contents.json",
    "content": "{\n  \"info\" : {\n    \"author\" : \"xcode\",\n    \"version\" : 1\n  }\n}\n"
  },
  {
    "path": "app/EXO/EXO/Services/BugReportService.swift",
    "content": "import Foundation\n\nstruct BugReportOutcome: Equatable {\n    let success: Bool\n    let message: String\n}\n\nenum BugReportError: LocalizedError {\n    case invalidEndpoint\n    case presignedUrlFailed(String)\n    case uploadFailed(String)\n    case collectFailed(String)\n\n    var errorDescription: String? {\n        switch self {\n        case .invalidEndpoint:\n            return \"Bug report endpoint is invalid.\"\n        case .presignedUrlFailed(let message):\n            return \"Failed to get presigned URLs: \\(message)\"\n        case .uploadFailed(let message):\n            return \"Bug report upload failed: \\(message)\"\n        case .collectFailed(let message):\n            return \"Bug report collection failed: \\(message)\"\n        }\n    }\n}\n\nstruct BugReportService {\n    private struct PresignedUrlsRequest: Codable {\n        let keys: [String]\n    }\n\n    private struct PresignedUrlsResponse: Codable {\n        let urls: [String: String]\n        let expiresIn: Int?\n    }\n\n    func sendReport(\n        baseURL: URL = URL(string: \"http://127.0.0.1:52415\")!,\n        now: Date = Date(),\n        isManual: Bool = false,\n        userDescription: String? = nil\n    ) async throws -> BugReportOutcome {\n        let timestamp = Self.runTimestampString(now)\n        let dayPrefix = Self.dayPrefixString(now)\n        let prefix = \"reports/\\(dayPrefix)/\\(timestamp)/\"\n\n        let logFiles = readAllLogs()\n        let ifconfigText = try await captureIfconfig()\n        let hostName = Host.current().localizedName ?? \"unknown\"\n        let debugInfo = readDebugInfo()\n\n        let stateData = try await fetch(url: baseURL.appendingPathComponent(\"state\"))\n\n        // Extract cluster TB bridge status from exo state\n        let clusterTbBridgeStatus = extractClusterTbBridgeStatus(from: stateData)\n\n        let reportJSON = makeReportJson(\n            timestamp: timestamp,\n            hostName: hostName,\n            ifconfig: ifconfigText,\n            debugInfo: debugInfo,\n            isManual: isManual,\n            clusterTbBridgeStatus: clusterTbBridgeStatus,\n            userDescription: userDescription\n        )\n\n        let eventLogFiles = readAllEventLogs()\n\n        var uploads: [(path: String, data: Data?)] = logFiles.map { (path, data) in\n            (\"\\(prefix)\\(path)\", data)\n        }\n        uploads.append(\n            contentsOf: eventLogFiles.map { (path, data) in\n                (\"\\(prefix)\\(path)\", data as Data?)\n            })\n        uploads.append(contentsOf: [\n            (\"\\(prefix)state.json\", stateData),\n            (\"\\(prefix)report.json\", reportJSON),\n        ])\n\n        let uploadItems: [(key: String, body: Data)] = uploads.compactMap { item in\n            guard let body = item.data else { return nil }\n            return (key: item.path, body: body)\n        }\n\n        guard !uploadItems.isEmpty else {\n            return BugReportOutcome(success: false, message: \"No data to upload\")\n        }\n\n        let presignedUrls = try await fetchPresignedUploadUrls(keys: uploadItems.map(\\.key))\n        for item in uploadItems {\n            guard let urlString = presignedUrls[item.key], let url = URL(string: urlString) else {\n                throw BugReportError.uploadFailed(\"Missing presigned URL for \\(item.key)\")\n            }\n            try await uploadToPresignedUrl(url: url, body: item.body)\n        }\n\n        return BugReportOutcome(\n            success: true, 
message: \"Bug Report sent. Thank you for helping to improve EXO 1.0.\")\n    }\n\n    private static func dayPrefixString(_ date: Date) -> String {\n        var calendar = Calendar(identifier: .gregorian)\n        calendar.timeZone = TimeZone(secondsFromGMT: 0) ?? .current\n        let components = calendar.dateComponents([.year, .month, .day], from: date)\n        let year = components.year ?? 0\n        let month = components.month ?? 0\n        let day = components.day ?? 0\n        return String(format: \"%04d/%02d/%02d\", year, month, day)\n    }\n\n    private static func runTimestampString(_ date: Date) -> String {\n        let formatter = DateFormatter()\n        formatter.locale = Locale(identifier: \"en_US_POSIX\")\n        formatter.timeZone = TimeZone(secondsFromGMT: 0) ?? .current\n        formatter.dateFormat = \"yyyy-MM-dd'T'HHmmss.SSS'Z'\"\n        return formatter.string(from: date)\n    }\n\n    private func fetchPresignedUploadUrls(keys: [String], bundle: Bundle = .main) async throws\n        -> [String: String]\n    {\n        guard\n            let endpointString = bundle.infoDictionary?[\"EXOBugReportPresignedUrlEndpoint\"]\n                as? String\n        else {\n            throw BugReportError.invalidEndpoint\n        }\n        let trimmedEndpointString = endpointString.trimmingCharacters(in: .whitespacesAndNewlines)\n        guard !trimmedEndpointString.isEmpty, let endpoint = URL(string: trimmedEndpointString)\n        else {\n            throw BugReportError.invalidEndpoint\n        }\n\n        var request = URLRequest(url: endpoint)\n        request.httpMethod = \"POST\"\n        request.timeoutInterval = 10\n        request.setValue(\"application/json\", forHTTPHeaderField: \"Content-Type\")\n\n        let encoder = JSONEncoder()\n        request.httpBody = try encoder.encode(PresignedUrlsRequest(keys: keys))\n\n        let (data, response) = try await URLSession.shared.data(for: request)\n        guard let http = response as? HTTPURLResponse else {\n            throw BugReportError.presignedUrlFailed(\"Non-HTTP response\")\n        }\n        guard (200..<300).contains(http.statusCode) else {\n            throw BugReportError.presignedUrlFailed(\"HTTP status \\(http.statusCode)\")\n        }\n\n        let decoder = JSONDecoder()\n        let decoded = try decoder.decode(PresignedUrlsResponse.self, from: data)\n        return decoded.urls\n    }\n\n    private func readAllLogs() -> [(path: String, data: Data)] {\n        let dir = URL(fileURLWithPath: NSHomeDirectory())\n            .appendingPathComponent(\".exo\")\n            .appendingPathComponent(\"exo_log\")\n        var results: [(path: String, data: Data)] = []\n\n        let contents = (try? FileManager.default.contentsOfDirectory(atPath: dir.path)) ?? []\n        for name in contents {\n            if let data = try? Data(contentsOf: dir.appendingPathComponent(name)) {\n                results.append((\"exo_log/\\(name)\", data))\n            }\n        }\n\n        return results\n    }\n\n    private func readAllEventLogs() -> [(path: String, data: Data)] {\n        let eventLogDir = URL(fileURLWithPath: NSHomeDirectory())\n            .appendingPathComponent(\".exo\")\n            .appendingPathComponent(\"event_log\")\n        var results: [(path: String, data: Data)] = []\n\n        for subdir in [\"master\", \"api\"] {\n            let dir = eventLogDir.appendingPathComponent(subdir)\n            let contents =\n                (try? 
FileManager.default.contentsOfDirectory(atPath: dir.path)) ?? []\n            for name in contents where name.hasPrefix(\"events.\") {\n                if let data = try? Data(contentsOf: dir.appendingPathComponent(name)) {\n                    results.append((\"event_log/\\(subdir)/\\(name)\", data))\n                }\n            }\n        }\n\n        return results\n    }\n\n    private func captureIfconfig() async throws -> String {\n        let result = runCommand([\"/sbin/ifconfig\"])\n        guard result.exitCode == 0 else {\n            throw BugReportError.collectFailed(\n                result.error.isEmpty ? \"ifconfig failed\" : result.error)\n        }\n        return result.output\n    }\n\n    private func readDebugInfo() -> DebugInfo {\n        DebugInfo(\n            thunderboltBridgeDisabled: readThunderboltBridgeDisabled(),\n            interfaces: readInterfaces(),\n            rdma: readRDMADebugInfo()\n        )\n    }\n\n    private func readRDMADebugInfo() -> DebugInfo.RDMADebugInfo {\n        DebugInfo.RDMADebugInfo(\n            rdmaCtlStatus: safeRunCommand([\"/usr/bin/rdma_ctl\", \"status\"]),\n            ibvDevices: safeRunCommand([\"/usr/bin/ibv_devices\"]),\n            ibvDevinfo: safeRunCommand([\"/usr/bin/ibv_devinfo\"])\n        )\n    }\n\n    private func readThunderboltBridgeDisabled() -> Bool? {\n        // Dynamically find the Thunderbolt Bridge service (don't assume the name)\n        guard let serviceName = ThunderboltBridgeDetector.findThunderboltBridgeServiceName() else {\n            // No bridge containing Thunderbolt interfaces exists\n            return nil\n        }\n\n        guard let isEnabled = ThunderboltBridgeDetector.isServiceEnabled(serviceName: serviceName)\n        else {\n            return nil\n        }\n\n        // Return true if disabled, false if enabled\n        return !isEnabled\n    }\n\n    private func readInterfaces() -> [DebugInfo.InterfaceStatus] {\n        (0...7).map { \"en\\($0)\" }.map { iface in\n            let result = runCommand([\"/sbin/ifconfig\", iface])\n            guard result.exitCode == 0 else {\n                return DebugInfo.InterfaceStatus(name: iface, ip: nil)\n            }\n            let ip = firstInet(from: result.output)\n            return DebugInfo.InterfaceStatus(name: iface, ip: ip)\n        }\n    }\n\n    private func firstInet(from ifconfigOutput: String) -> String? {\n        for line in ifconfigOutput.split(separator: \"\\n\") {\n            let trimmed = line.trimmingCharacters(in: .whitespaces)\n            guard trimmed.hasPrefix(\"inet \") else { continue }\n            let parts = trimmed.split(separator: \" \")\n            if parts.count >= 2 {\n                let candidate = String(parts[1])\n                if candidate != \"127.0.0.1\" {\n                    return candidate\n                }\n            }\n        }\n        return nil\n    }\n\n    private func fetch(url: URL) async throws -> Data? {\n        var request = URLRequest(url: url)\n        request.timeoutInterval = 5\n        do {\n            let (data, response) = try await URLSession.shared.data(for: request)\n            guard let http = response as? 
HTTPURLResponse, (200..<300).contains(http.statusCode)\n            else {\n                return nil\n            }\n            return data\n        } catch {\n            return nil\n        }\n    }\n\n    private func uploadToPresignedUrl(url: URL, body: Data) async throws {\n        let maxAttempts = 2\n        var lastError: Error?\n\n        for attempt in 1...maxAttempts {\n            do {\n                var request = URLRequest(url: url)\n                request.httpMethod = \"PUT\"\n                request.httpBody = body\n                request.timeoutInterval = 30\n\n                let (_, response) = try await URLSession.shared.data(for: request)\n                guard let http = response as? HTTPURLResponse else {\n                    throw BugReportError.uploadFailed(\"Non-HTTP response\")\n                }\n                guard (200..<300).contains(http.statusCode) else {\n                    throw BugReportError.uploadFailed(\"HTTP status \\(http.statusCode)\")\n                }\n                return\n            } catch {\n                lastError = error\n                if attempt < maxAttempts {\n                    try await Task.sleep(nanoseconds: 400_000_000)\n                }\n            }\n        }\n\n        throw BugReportError.uploadFailed(lastError?.localizedDescription ?? \"Unknown error\")\n    }\n\n    private func makeReportJson(\n        timestamp: String,\n        hostName: String,\n        ifconfig: String,\n        debugInfo: DebugInfo,\n        isManual: Bool,\n        clusterTbBridgeStatus: [[String: Any]]?,\n        userDescription: String? = nil\n    ) -> Data? {\n        let system = readSystemMetadata()\n        let exo = readExoMetadata()\n        var payload: [String: Any] = [\n            \"timestamp\": timestamp,\n            \"host\": hostName,\n            \"ifconfig\": ifconfig,\n            \"debug\": debugInfo.toDictionary(),\n            \"system\": system,\n            \"exo_version\": exo.version as Any,\n            \"exo_commit\": exo.commit as Any,\n            \"report_type\": isManual ? \"manual\" : \"automated\",\n        ]\n        if let tbStatus = clusterTbBridgeStatus {\n            payload[\"cluster_thunderbolt_bridge\"] = tbStatus\n        }\n        if let desc = userDescription, !desc.isEmpty {\n            payload[\"user_description\"] = desc\n        }\n        return try? JSONSerialization.data(withJSONObject: payload, options: [.prettyPrinted])\n    }\n\n    /// Extracts cluster-wide Thunderbolt Bridge status from exo state JSON\n    private func extractClusterTbBridgeStatus(from stateData: Data?) -> [[String: Any]]? {\n        guard let data = stateData,\n            let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any],\n            let nodeThunderboltBridge = json[\"node_thunderbolt_bridge\"] as? [String: [String: Any]]\n        else {\n            return nil\n        }\n\n        var result: [[String: Any]] = []\n        for (nodeId, status) in nodeThunderboltBridge {\n            var entry: [String: Any] = [\"node_id\": nodeId]\n            if let enabled = status[\"enabled\"] as? Bool {\n                entry[\"enabled\"] = enabled\n            }\n            if let exists = status[\"exists\"] as? Bool {\n                entry[\"exists\"] = exists\n            }\n            if let serviceName = status[\"service_name\"] as? 
String {\n                entry[\"service_name\"] = serviceName\n            }\n            result.append(entry)\n        }\n        return result.isEmpty ? nil : result\n    }\n\n    private func readSystemMetadata() -> [String: Any] {\n        let hostname = safeRunCommand([\"/bin/hostname\"])\n        let computerName = safeRunCommand([\"/usr/sbin/scutil\", \"--get\", \"ComputerName\"])\n        let localHostName = safeRunCommand([\"/usr/sbin/scutil\", \"--get\", \"LocalHostName\"])\n        let hostNameCommand = safeRunCommand([\"/usr/sbin/scutil\", \"--get\", \"HostName\"])\n        let hardwareModel = safeRunCommand([\"/usr/sbin/sysctl\", \"-n\", \"hw.model\"])\n        let hardwareProfile = safeRunCommand([\"/usr/sbin/system_profiler\", \"SPHardwareDataType\"])\n        let hardwareUUID = hardwareProfile.flatMap(extractHardwareUUID)\n\n        let osVersion = safeRunCommand([\"/usr/bin/sw_vers\", \"-productVersion\"])\n        let osBuild = safeRunCommand([\"/usr/bin/sw_vers\", \"-buildVersion\"])\n        let kernel = safeRunCommand([\"/usr/bin/uname\", \"-srv\"])\n        let arch = safeRunCommand([\"/usr/bin/uname\", \"-m\"])\n\n        let routeInfo = safeRunCommand([\"/sbin/route\", \"-n\", \"get\", \"default\"])\n        let defaultInterface = routeInfo.flatMap(parseDefaultInterface)\n        let defaultIP = defaultInterface.flatMap { iface in\n            safeRunCommand([\"/usr/sbin/ipconfig\", \"getifaddr\", iface])\n        }\n        let defaultMac = defaultInterface.flatMap { iface in\n            safeRunCommand([\"/sbin/ifconfig\", iface]).flatMap(parseEtherAddress)\n        }\n\n        let user = safeRunCommand([\"/usr/bin/whoami\"])\n        let consoleUser = safeRunCommand([\"/usr/bin/stat\", \"-f%Su\", \"/dev/console\"])\n        let uptime = safeRunCommand([\"/usr/bin/uptime\"])\n        let diskRoot = safeRunCommand([\n            \"/bin/sh\", \"-c\", \"/bin/df -h / | awk 'NR==2 {print $1, $2, $3, $4, $5}'\",\n        ])\n\n        let interfacesList = safeRunCommand([\"/usr/sbin/ipconfig\", \"getiflist\"])\n        let interfacesAndIPs =\n            interfacesList?\n            .split(whereSeparator: { $0 == \" \" || $0 == \"\\n\" })\n            .compactMap { iface -> [String: Any]? in\n                let name = String(iface)\n                guard let ip = safeRunCommand([\"/usr/sbin/ipconfig\", \"getifaddr\", name]) else {\n                    return nil\n                }\n                return [\"name\": name, \"ip\": ip]\n            } ?? 
[]\n\n        let wifiSSID: String?\n        let airportPath =\n            \"/System/Library/PrivateFrameworks/Apple80211.framework/Versions/Current/Resources/airport\"\n        if FileManager.default.isExecutableFile(atPath: airportPath) {\n            wifiSSID = safeRunCommand([airportPath, \"-I\"]).flatMap(parseWifiSSID)\n        } else {\n            wifiSSID = nil\n        }\n\n        return [\n            \"hostname\": hostname as Any,\n            \"computer_name\": computerName as Any,\n            \"local_hostname\": localHostName as Any,\n            \"host_name\": hostNameCommand as Any,\n            \"hardware_model\": hardwareModel as Any,\n            \"hardware_profile\": hardwareProfile as Any,\n            \"hardware_uuid\": hardwareUUID as Any,\n            \"os_version\": osVersion as Any,\n            \"os_build\": osBuild as Any,\n            \"kernel\": kernel as Any,\n            \"arch\": arch as Any,\n            \"default_interface\": defaultInterface as Any,\n            \"default_ip\": defaultIP as Any,\n            \"default_mac\": defaultMac as Any,\n            \"user\": user as Any,\n            \"console_user\": consoleUser as Any,\n            \"uptime\": uptime as Any,\n            \"disk_root\": diskRoot as Any,\n            \"interfaces_and_ips\": interfacesAndIPs,\n            \"ipconfig_getiflist\": interfacesList as Any,\n            \"wifi_ssid\": wifiSSID as Any,\n        ]\n    }\n\n    private func readExoMetadata(bundle: Bundle = .main) -> (version: String?, commit: String?) {\n        let info = bundle.infoDictionary ?? [:]\n        let tag = info[\"EXOBuildTag\"] as? String\n        let short = info[\"CFBundleShortVersionString\"] as? String\n        let version = [tag, short]\n            .compactMap { $0?.trimmingCharacters(in: .whitespacesAndNewlines) }\n            .first { !$0.isEmpty }\n        let commit = (info[\"EXOBuildCommit\"] as? String)?\n            .trimmingCharacters(in: .whitespacesAndNewlines)\n        let normalizedCommit = (commit?.isEmpty == true) ? nil : commit\n        return (version: version, commit: normalizedCommit)\n    }\n\n    private func safeRunCommand(_ arguments: [String]) -> String? {\n        let result = runCommand(arguments)\n        guard result.exitCode == 0 else { return nil }\n        let trimmed = result.output.trimmingCharacters(in: .whitespacesAndNewlines)\n        return trimmed.isEmpty ? nil : trimmed\n    }\n\n    private func extractHardwareUUID(from hardwareProfile: String) -> String? {\n        hardwareProfile\n            .split(separator: \"\\n\")\n            .first { $0.contains(\"Hardware UUID\") }?\n            .split(separator: \":\")\n            .dropFirst()\n            .joined(separator: \":\")\n            .trimmingCharacters(in: .whitespaces)\n    }\n\n    private func parseDefaultInterface(from routeOutput: String) -> String? {\n        for line in routeOutput.split(separator: \"\\n\") {\n            let trimmed = line.trimmingCharacters(in: .whitespaces)\n            if trimmed.hasPrefix(\"interface: \") {\n                return trimmed.replacingOccurrences(of: \"interface: \", with: \"\")\n            }\n        }\n        return nil\n    }\n\n    private func parseEtherAddress(from ifconfigOutput: String) -> String? 
{\n        for line in ifconfigOutput.split(separator: \"\\n\") {\n            let trimmed = line.trimmingCharacters(in: .whitespaces)\n            if trimmed.hasPrefix(\"ether \") {\n                return trimmed.replacingOccurrences(of: \"ether \", with: \"\")\n            }\n        }\n        return nil\n    }\n\n    private func parseWifiSSID(from airportOutput: String) -> String? {\n        for line in airportOutput.split(separator: \"\\n\") {\n            let trimmed = line.trimmingCharacters(in: .whitespaces)\n            if trimmed.hasPrefix(\"SSID:\") {\n                return trimmed.replacingOccurrences(of: \"SSID:\", with: \"\").trimmingCharacters(\n                    in: .whitespaces)\n            }\n        }\n        return nil\n    }\n\n    private func runCommand(_ arguments: [String]) -> CommandResult {\n        let process = Process()\n        process.executableURL = arguments.first.map { URL(fileURLWithPath: $0) }\n        process.arguments = Array(arguments.dropFirst())\n\n        let stdout = Pipe()\n        let stderr = Pipe()\n        process.standardOutput = stdout\n        process.standardError = stderr\n\n        do {\n            try process.run()\n        } catch {\n            return CommandResult(exitCode: -1, output: \"\", error: error.localizedDescription)\n        }\n        process.waitUntilExit()\n\n        let outputData = stdout.fileHandleForReading.readDataToEndOfFile()\n        let errorData = stderr.fileHandleForReading.readDataToEndOfFile()\n\n        return CommandResult(\n            exitCode: process.terminationStatus,\n            output: String(decoding: outputData, as: UTF8.self),\n            error: String(decoding: errorData, as: UTF8.self)\n        )\n    }\n}\n\nprivate struct DebugInfo {\n    let thunderboltBridgeDisabled: Bool?\n    let interfaces: [InterfaceStatus]\n    let rdma: RDMADebugInfo\n\n    struct InterfaceStatus {\n        let name: String\n        let ip: String?\n\n        func toDictionary() -> [String: Any] {\n            [\n                \"name\": name,\n                \"ip\": ip as Any,\n            ]\n        }\n    }\n\n    struct RDMADebugInfo {\n        let rdmaCtlStatus: String?\n        let ibvDevices: String?\n        let ibvDevinfo: String?\n\n        func toDictionary() -> [String: Any] {\n            [\n                \"rdma_ctl_status\": rdmaCtlStatus as Any,\n                \"ibv_devices\": ibvDevices as Any,\n                \"ibv_devinfo\": ibvDevinfo as Any,\n            ]\n        }\n    }\n\n    func toDictionary() -> [String: Any] {\n        [\n            \"thunderbolt_bridge_disabled\": thunderboltBridgeDisabled as Any,\n            \"interfaces\": interfaces.map { $0.toDictionary() },\n            \"rdma\": rdma.toDictionary(),\n        ]\n    }\n}\n\nprivate struct CommandResult {\n    let exitCode: Int32\n    let output: String\n    let error: String\n}\n"
  },
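`sendReport` is a two-phase protocol: it first POSTs the object keys to the presigned-URL endpoint configured in Info.plist (`EXOBugReportPresignedUrlEndpoint`), then PUTs each collected artifact (logs, event logs, `state.json`, `report.json`) to its returned URL. A usage sketch from a hypothetical UI action; the function name and description text are illustrative:

```swift
// Sketch only: fire a manual report from a settings or menu action.
// This only succeeds in a bundle that defines the presigned-URL endpoint.
func submitBugReport() {
    Task {
        do {
            let outcome = try await BugReportService().sendReport(
                isManual: true,
                userDescription: "Thunderbolt Bridge keeps re-enabling after reboot"
            )
            print(outcome.message)  // "Bug Report sent. ..."
        } catch {
            print("Bug report failed: \(error.localizedDescription)")
        }
    }
}
```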
  {
    "path": "app/EXO/EXO/Services/ClusterStateService.swift",
    "content": "import Combine\nimport Foundation\n\n@MainActor\nfinal class ClusterStateService: ObservableObject {\n    @Published private(set) var latestSnapshot: ClusterState?\n    @Published private(set) var lastError: String?\n    @Published private(set) var lastActionMessage: String?\n    @Published private(set) var modelOptions: [ModelOption] = []\n    @Published private(set) var localNodeId: String?\n\n    private var timer: Timer?\n    private let decoder: JSONDecoder\n    private let session: URLSession\n    private let baseURL: URL\n    private let endpoint: URL\n\n    init(\n        baseURL: URL = URL(string: \"http://127.0.0.1:52415\")!,\n        session: URLSession = .shared\n    ) {\n        self.baseURL = baseURL\n        self.endpoint = baseURL.appendingPathComponent(\"state\")\n        self.session = session\n        let decoder = JSONDecoder()\n        decoder.keyDecodingStrategy = .convertFromSnakeCase\n        self.decoder = decoder\n    }\n\n    func startPolling(interval: TimeInterval = 0.5) {\n        stopPolling()\n        Task {\n            await fetchLocalNodeId()\n            await fetchModels()\n            await fetchSnapshot()\n        }\n        timer = Timer.scheduledTimer(withTimeInterval: interval, repeats: true) { [weak self] _ in\n            Task { await self?.fetchSnapshot() }\n        }\n    }\n\n    func stopPolling() {\n        timer?.invalidate()\n        timer = nil\n    }\n\n    func resetTransientState() {\n        latestSnapshot = nil\n        lastError = nil\n        lastActionMessage = nil\n        localNodeId = nil\n    }\n\n    private func fetchLocalNodeId() async {\n        do {\n            let url = baseURL.appendingPathComponent(\"node_id\")\n            var request = URLRequest(url: url)\n            request.cachePolicy = .reloadIgnoringLocalCacheData\n            let (data, response) = try await session.data(for: request)\n            guard let httpResponse = response as? HTTPURLResponse,\n                (200..<300).contains(httpResponse.statusCode)\n            else {\n                return\n            }\n            if let nodeId = try? decoder.decode(String.self, from: data) {\n                localNodeId = nodeId\n            }\n        } catch {\n            // Silently ignore - localNodeId will remain nil and retry on next poll\n        }\n    }\n\n    private func fetchSnapshot() async {\n        // Retry fetching local node ID if not yet set\n        if localNodeId == nil {\n            await fetchLocalNodeId()\n        }\n        do {\n            var request = URLRequest(url: endpoint)\n            request.cachePolicy = .reloadIgnoringLocalCacheData\n            let (data, response) = try await session.data(for: request)\n            guard let httpResponse = response as? 
HTTPURLResponse else {\n                throw URLError(.badServerResponse)\n            }\n            guard (200..<300).contains(httpResponse.statusCode) else {\n                throw URLError(.badServerResponse)\n            }\n            let snapshot = try decoder.decode(ClusterState.self, from: data)\n            latestSnapshot = snapshot\n            if modelOptions.isEmpty {\n                Task { await fetchModels() }\n            }\n            lastError = nil\n        } catch {\n            lastError = error.localizedDescription\n        }\n    }\n\n    func deleteInstance(_ id: String) async {\n        do {\n            var request = URLRequest(url: baseURL.appendingPathComponent(\"instance/\\(id)\"))\n            request.httpMethod = \"DELETE\"\n            request.setValue(\"application/json\", forHTTPHeaderField: \"Accept\")\n            let (_, response) = try await session.data(for: request)\n            guard let httpResponse = response as? HTTPURLResponse else {\n                throw URLError(.badServerResponse)\n            }\n            guard (200..<300).contains(httpResponse.statusCode) else {\n                throw URLError(.badServerResponse)\n            }\n            lastActionMessage = \"Instance deleted\"\n            await fetchSnapshot()\n        } catch {\n            lastError = \"Failed to delete instance: \\(error.localizedDescription)\"\n        }\n    }\n\n    func launchInstance(modelId: String, sharding: String, instanceMeta: String, minNodes: Int)\n        async\n    {\n        do {\n            var request = URLRequest(url: baseURL.appendingPathComponent(\"instance\"))\n            request.httpMethod = \"POST\"\n            request.setValue(\"application/json\", forHTTPHeaderField: \"Content-Type\")\n            let payload: [String: Any] = [\n                \"model_id\": modelId,\n                \"sharding\": sharding,\n                \"instance_meta\": instanceMeta,\n                \"min_nodes\": minNodes,\n            ]\n            request.httpBody = try JSONSerialization.data(withJSONObject: payload, options: [])\n            let (_, response) = try await session.data(for: request)\n            guard let httpResponse = response as? HTTPURLResponse else {\n                throw URLError(.badServerResponse)\n            }\n            guard (200..<300).contains(httpResponse.statusCode) else {\n                throw URLError(.badServerResponse)\n            }\n            lastActionMessage = \"Instance launched\"\n            await fetchSnapshot()\n        } catch {\n            lastError = \"Failed to launch instance: \\(error.localizedDescription)\"\n        }\n    }\n\n    func fetchModels() async {\n        do {\n            let url = baseURL.appendingPathComponent(\"models\")\n            let (data, response) = try await session.data(from: url)\n            guard let httpResponse = response as? HTTPURLResponse,\n                (200..<300).contains(httpResponse.statusCode)\n            else {\n                throw URLError(.badServerResponse)\n            }\n            let list = try decoder.decode(ModelListResponse.self, from: data)\n            modelOptions = list.data.map { ModelOption(id: $0.id, displayName: $0.name ?? 
$0.id) }\n        } catch {\n            lastError = \"Failed to load models: \\(error.localizedDescription)\"\n        }\n    }\n}\n\nstruct ModelOption: Identifiable {\n    let id: String\n    let displayName: String\n}\n\nstruct ModelListResponse: Decodable {\n    let data: [ModelListModel]\n}\n\nstruct ModelListModel: Decodable {\n    let id: String\n    let name: String?\n}\n"
  },
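`ClusterStateService` is a plain `ObservableObject` that polls `/state` twice a second, so a view only needs to subscribe and start/stop the timer. A minimal sketch of a consuming view (the view name and layout are illustrative, not part of the app):

```swift
import SwiftUI

struct ClusterOverviewView: View {
    @StateObject private var clusterState = ClusterStateService()

    var body: some View {
        VStack(alignment: .leading) {
            if let snapshot = clusterState.latestSnapshot {
                // overview() aggregates RAM and counts across nodeProfiles.
                let overview = snapshot.overview()
                Text("\(overview.nodeCount) nodes · \(overview.instanceCount) instances")
                Text(String(format: "%.1f / %.1f GB RAM in use", overview.usedRam, overview.totalRam))
            } else if let error = clusterState.lastError {
                Text(error).foregroundStyle(.red)
            } else {
                Text("Connecting to local exo node…")
            }
        }
        .onAppear { clusterState.startPolling() }
        .onDisappear { clusterState.stopPolling() }
    }
}
```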
  {
    "path": "app/EXO/EXO/Services/LocalNetworkChecker.swift",
    "content": "import Foundation\nimport Network\nimport os.log\n\n/// Checks if the app's local network permission is actually functional.\n///\n/// macOS local network permission can appear enabled in System Preferences but not\n/// actually work after a restart. This service uses NWConnection to mDNS multicast\n/// to verify actual connectivity.\n@MainActor\nfinal class LocalNetworkChecker: ObservableObject {\n    enum Status: Equatable {\n        case unknown\n        case checking\n        case working\n        case notWorking(reason: String)\n\n        var isHealthy: Bool {\n            if case .working = self { return true }\n            return false\n        }\n\n        var displayText: String {\n            switch self {\n            case .unknown:\n                return \"Unknown\"\n            case .checking:\n                return \"Checking...\"\n            case .working:\n                return \"Working\"\n            case .notWorking(let reason):\n                return reason\n            }\n        }\n    }\n\n    private static let logger = Logger(subsystem: \"io.exo.EXO\", category: \"LocalNetworkChecker\")\n    private static let hasCompletedInitialCheckKey = \"LocalNetworkChecker.hasCompletedInitialCheck\"\n\n    @Published private(set) var status: Status = .unknown\n\n    private var connection: NWConnection?\n    private var checkTask: Task<Void, Never>?\n    private var periodicTask: Task<Void, Never>?\n\n    /// Whether we've completed at least one check (stored in UserDefaults)\n    private var hasCompletedInitialCheck: Bool {\n        get { UserDefaults.standard.bool(forKey: Self.hasCompletedInitialCheckKey) }\n        set { UserDefaults.standard.set(newValue, forKey: Self.hasCompletedInitialCheckKey) }\n    }\n\n    /// Checks if local network access is working (one-time check).\n    func check() {\n        performCheck()\n    }\n\n    /// Starts periodic checking of local network access.\n    /// Re-checks every `interval` seconds so the warning disappears when user grants permission.\n    func startPeriodicChecking(interval: TimeInterval = 10) {\n        stopPeriodicChecking()\n        // Do an immediate check first\n        performCheck()\n        // Then schedule periodic checks\n        periodicTask = Task { [weak self] in\n            while !Task.isCancelled {\n                try? await Task.sleep(nanoseconds: UInt64(interval * 1_000_000_000))\n                guard !Task.isCancelled else { break }\n                self?.performCheck()\n            }\n        }\n    }\n\n    /// Stops periodic checking.\n    func stopPeriodicChecking() {\n        periodicTask?.cancel()\n        periodicTask = nil\n    }\n\n    private func performCheck() {\n        checkTask?.cancel()\n        // Only show \"checking\" status on first check to avoid UI flicker\n        if status == .unknown {\n            status = .checking\n        }\n\n        // Use longer timeout on first launch to allow time for permission prompt\n        let isFirstCheck = !hasCompletedInitialCheck\n        let timeout: UInt64 = isFirstCheck ? 
30_000_000_000 : 3_000_000_000\n\n        checkTask = Task { [weak self] in\n            guard let self else { return }\n\n            Self.logger.debug(\"Checking local network connectivity (first check: \\(isFirstCheck))\")\n            let previousStatus = self.status\n            let result = await self.checkConnectivity(timeout: timeout)\n            self.status = result\n            self.hasCompletedInitialCheck = true\n\n            // Only log on state changes or first check to reduce noise\n            if isFirstCheck || result != previousStatus {\n                Self.logger.info(\"Local network check: \\(result.displayText)\")\n            }\n        }\n    }\n\n    /// Checks connectivity using NWConnection to mDNS multicast.\n    /// The connection attempt triggers the permission prompt if not yet shown.\n    private func checkConnectivity(timeout: UInt64) async -> Status {\n        connection?.cancel()\n        connection = nil\n\n        // mDNS multicast address - same as libp2p uses for peer discovery\n        let host = NWEndpoint.Host(\"224.0.0.251\")\n        let port = NWEndpoint.Port(integerLiteral: 5353)\n\n        let params = NWParameters.udp\n        params.allowLocalEndpointReuse = true\n\n        let conn = NWConnection(host: host, port: port, using: params)\n        connection = conn\n\n        return await withCheckedContinuation { continuation in\n            var hasResumed = false\n            let lock = NSLock()\n\n            let resumeOnce: (Status) -> Void = { status in\n                lock.lock()\n                defer { lock.unlock() }\n                guard !hasResumed else { return }\n                hasResumed = true\n                continuation.resume(returning: status)\n            }\n\n            conn.stateUpdateHandler = { state in\n                switch state {\n                case .ready:\n                    resumeOnce(.working)\n                case .waiting(let error):\n                    let errorStr = \"\\(error)\"\n                    if errorStr.contains(\"54\") || errorStr.contains(\"ECONNRESET\") {\n                        resumeOnce(.notWorking(reason: \"Connection blocked\"))\n                    }\n                // Otherwise keep waiting - might be showing permission prompt\n                case .failed(let error):\n                    let errorStr = \"\\(error)\"\n                    if errorStr.contains(\"65\") || errorStr.contains(\"EHOSTUNREACH\")\n                        || errorStr.contains(\"permission\") || errorStr.contains(\"denied\")\n                    {\n                        resumeOnce(.notWorking(reason: \"Permission denied\"))\n                    } else {\n                        resumeOnce(.notWorking(reason: \"Failed: \\(error.localizedDescription)\"))\n                    }\n                case .cancelled, .setup, .preparing:\n                    break\n                @unknown default:\n                    break\n                }\n            }\n\n            conn.start(queue: .main)\n\n            Task {\n                try? 
await Task.sleep(nanoseconds: timeout)\n                let state = conn.state\n                switch state {\n                case .ready:\n                    resumeOnce(.working)\n                case .waiting, .preparing, .setup:\n                    resumeOnce(.notWorking(reason: \"Timeout (may be blocked)\"))\n                default:\n                    resumeOnce(.notWorking(reason: \"Timeout\"))\n                }\n            }\n        }\n    }\n\n    func stop() {\n        stopPeriodicChecking()\n        checkTask?.cancel()\n        checkTask = nil\n        connection?.cancel()\n        connection = nil\n    }\n}\n"
  },
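Because the checker publishes its result as a `Status`, a view can observe it directly and surface a warning only in the `notWorking` case. A sketch with an illustrative view name:

```swift
import SwiftUI

// Banner that appears only while the mDNS probe reports a problem.
struct LocalNetworkBanner: View {
    @StateObject private var checker = LocalNetworkChecker()

    var body: some View {
        Group {
            if case .notWorking(let reason) = checker.status {
                Label("Local network unavailable: \(reason)",
                      systemImage: "exclamationmark.triangle")
            }
        }
        .onAppear { checker.startPeriodicChecking() }  // re-probes every 10 s by default
        .onDisappear { checker.stop() }
    }
}
```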
  {
    "path": "app/EXO/EXO/Services/NetworkSetupHelper.swift",
    "content": "import AppKit\nimport Foundation\nimport os.log\n\nenum NetworkSetupHelper {\n    private static let logger = Logger(subsystem: \"io.exo.EXO\", category: \"NetworkSetup\")\n    private static let daemonLabel = \"io.exo.networksetup\"\n    private static let scriptDestination =\n        \"/Library/Application Support/EXO/disable_bridge.sh\"\n    // Legacy script path from older versions\n    private static let legacyScriptDestination =\n        \"/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh\"\n    private static let plistDestination = \"/Library/LaunchDaemons/io.exo.networksetup.plist\"\n    private static let requiredStartInterval: Int = 1786\n\n    private static let setupScript = \"\"\"\n        #!/usr/bin/env bash\n\n        set -euo pipefail\n\n        # Wait for macOS to finish network setup after boot\n        sleep 20\n\n        PREFS=\"/Library/Preferences/SystemConfiguration/preferences.plist\"\n\n        # Remove bridge0 interface\n        ifconfig bridge0 &>/dev/null && {\n          ifconfig bridge0 | grep -q 'member' && {\n            ifconfig bridge0 | awk '/member/ {print $2}' | xargs -n1 ifconfig bridge0 deletem 2>/dev/null || true\n          }\n          ifconfig bridge0 destroy 2>/dev/null || true\n        }\n\n        # Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist\n        /usr/libexec/PlistBuddy -c \"Delete :VirtualNetworkInterfaces:Bridge:bridge0\" \"$PREFS\" 2>/dev/null || true\n\n        networksetup -listlocations | grep -q exo || {\n          networksetup -createlocation exo\n        }\n\n        networksetup -switchtolocation exo\n        networksetup -listallhardwareports \\\\\n          | awk -F': ' '/Hardware Port: / {print $2}' \\\\\n          | while IFS=\":\" read -r name; do\n              case \"$name\" in\n                \"Ethernet Adapter\"*)\n                        ;;\n                \"Thunderbolt Bridge\")\n                        ;;\n                \"Thunderbolt \"*)\n                  networksetup -listallnetworkservices \\\\\n                    | grep -q \"EXO $name\" \\\\\n                      || networksetup -createnetworkservice \"EXO $name\" \"$name\" 2>/dev/null \\\\\n                      || continue\n                  networksetup -setdhcp \"EXO $name\"\n                        ;;\n                *)\n                  networksetup -listallnetworkservices \\\\\n                    | grep -q \"$name\" \\\\\n                      || networksetup -createnetworkservice \"$name\" \"$name\" 2>/dev/null \\\\\n                      || continue\n                        ;;\n              esac\n            done\n\n        networksetup -listnetworkservices | grep -q \"Thunderbolt Bridge\" && {\n          networksetup -setnetworkserviceenabled \"Thunderbolt Bridge\" off\n        } || true\n        \"\"\"\n\n    /// Prompts user and installs the LaunchDaemon if not already installed.\n    /// Shows an alert explaining what will be installed before requesting admin privileges.\n    static func promptAndInstallIfNeeded() {\n        // Use .utility priority to match NSAppleScript's internal QoS and avoid priority inversion\n        Task.detached(priority: .utility) {\n            // If already correctly installed, skip\n            if daemonAlreadyInstalled() {\n                return\n            }\n\n            // Show alert on main thread\n            let shouldInstall = await MainActor.run {\n                let alert = NSAlert()\n                alert.messageText = \"EXO Network 
Configuration\"\n                alert.informativeText =\n                    \"EXO needs to install a system service to configure local networking. This will disable Thunderbolt Bridge (preventing packet storms) and install a Network Location.\\n\\nYou will be prompted for your password.\"\n                alert.alertStyle = .informational\n                alert.addButton(withTitle: \"Install\")\n                alert.addButton(withTitle: \"Not Now\")\n                return alert.runModal() == .alertFirstButtonReturn\n            }\n\n            guard shouldInstall else {\n                logger.info(\"User deferred network setup daemon installation\")\n                return\n            }\n\n            do {\n                try installLaunchDaemon()\n                logger.info(\"Network setup launch daemon installed and started\")\n            } catch {\n                logger.error(\n                    \"Network setup launch daemon failed: \\(error.localizedDescription, privacy: .public)\"\n                )\n            }\n        }\n    }\n\n    /// Removes all EXO network setup components from the system.\n    /// This includes the LaunchDaemon, scripts, logs, and network location.\n    /// Requires admin privileges.\n    static func uninstall() throws {\n        let uninstallScript = makeUninstallScript()\n        try runShellAsAdmin(uninstallScript)\n        logger.info(\"EXO network setup components removed successfully\")\n    }\n\n    /// Checks if there are any EXO network components installed that need cleanup\n    static func hasInstalledComponents() -> Bool {\n        let manager = FileManager.default\n        let scriptExists = manager.fileExists(atPath: scriptDestination)\n        let legacyScriptExists = manager.fileExists(atPath: legacyScriptDestination)\n        let plistExists = manager.fileExists(atPath: plistDestination)\n        return scriptExists || legacyScriptExists || plistExists\n    }\n\n    private static func daemonAlreadyInstalled() -> Bool {\n        let manager = FileManager.default\n        let scriptExists = manager.fileExists(atPath: scriptDestination)\n        let plistExists = manager.fileExists(atPath: plistDestination)\n        guard scriptExists, plistExists else { return false }\n        guard\n            let installedScript = try? String(contentsOfFile: scriptDestination, encoding: .utf8),\n            installedScript.trimmingCharacters(in: .whitespacesAndNewlines)\n                == setupScript.trimmingCharacters(in: .whitespacesAndNewlines)\n        else {\n            return false\n        }\n        guard\n            let data = try? Data(contentsOf: URL(fileURLWithPath: plistDestination)),\n            let plist = try? PropertyListSerialization.propertyList(\n                from: data, options: [], format: nil) as? [String: Any]\n        else {\n            return false\n        }\n        guard\n            let interval = plist[\"StartInterval\"] as? Int,\n            interval == requiredStartInterval\n        else {\n            return false\n        }\n        if let programArgs = plist[\"ProgramArguments\"] as? 
[String],\n            programArgs.contains(scriptDestination) == false\n        {\n            return false\n        }\n        return true\n    }\n\n    private static func installLaunchDaemon() throws {\n        let installerScript = makeInstallerScript()\n        try runShellAsAdmin(installerScript)\n    }\n\n    private static func makeInstallerScript() -> String {\n        \"\"\"\n        set -euo pipefail\n\n        LABEL=\"\\(daemonLabel)\"\n        SCRIPT_DEST=\"\\(scriptDestination)\"\n        LEGACY_SCRIPT_DEST=\"\\(legacyScriptDestination)\"\n        PLIST_DEST=\"\\(plistDestination)\"\n        LOG_OUT=\"/var/log/\\(daemonLabel).log\"\n        LOG_ERR=\"/var/log/\\(daemonLabel).err.log\"\n\n        # First, completely remove any existing installation\n        launchctl bootout system/\"$LABEL\" 2>/dev/null || true\n        rm -f \"$PLIST_DEST\"\n        rm -f \"$SCRIPT_DEST\"\n        rm -f \"$LEGACY_SCRIPT_DEST\"\n        rm -f \"$LOG_OUT\" \"$LOG_ERR\"\n\n        # Install fresh\n        mkdir -p \"$(dirname \"$SCRIPT_DEST\")\"\n\n        cat > \"$SCRIPT_DEST\" <<'EOF_SCRIPT'\n        \\(setupScript)\n        EOF_SCRIPT\n        chmod 755 \"$SCRIPT_DEST\"\n\n        cat > \"$PLIST_DEST\" <<'EOF_PLIST'\n        <?xml version=\"1.0\" encoding=\"UTF-8\"?>\n        <!DOCTYPE plist PUBLIC \"-//Apple//DTD PLIST 1.0//EN\" \"http://www.apple.com/DTDs/PropertyList-1.0.dtd\">\n        <plist version=\"1.0\">\n        <dict>\n          <key>Label</key>\n          <string>\\(daemonLabel)</string>\n          <key>ProgramArguments</key>\n          <array>\n            <string>/bin/bash</string>\n            <string>\\(scriptDestination)</string>\n          </array>\n          <key>StartInterval</key>\n          <integer>\\(requiredStartInterval)</integer>\n          <key>RunAtLoad</key>\n          <true/>\n          <key>StandardOutPath</key>\n          <string>/var/log/\\(daemonLabel).log</string>\n          <key>StandardErrorPath</key>\n          <string>/var/log/\\(daemonLabel).err.log</string>\n        </dict>\n        </plist>\n        EOF_PLIST\n\n        launchctl bootstrap system \"$PLIST_DEST\"\n        launchctl enable system/\"$LABEL\"\n        launchctl kickstart -k system/\"$LABEL\"\n        \"\"\"\n    }\n\n    private static func makeUninstallScript() -> String {\n        \"\"\"\n        set -euo pipefail\n\n        LABEL=\"\\(daemonLabel)\"\n        SCRIPT_DEST=\"\\(scriptDestination)\"\n        LEGACY_SCRIPT_DEST=\"\\(legacyScriptDestination)\"\n        PLIST_DEST=\"\\(plistDestination)\"\n        LOG_OUT=\"/var/log/\\(daemonLabel).log\"\n        LOG_ERR=\"/var/log/\\(daemonLabel).err.log\"\n\n        # Unload the LaunchDaemon if running\n        launchctl bootout system/\"$LABEL\" 2>/dev/null || true\n\n        # Remove LaunchDaemon plist\n        rm -f \"$PLIST_DEST\"\n\n        # Remove the script (current and legacy paths) and parent directory if empty\n        rm -f \"$SCRIPT_DEST\"\n        rm -f \"$LEGACY_SCRIPT_DEST\"\n        rmdir \"$(dirname \"$SCRIPT_DEST\")\" 2>/dev/null || true\n\n        # Remove log files\n        rm -f \"$LOG_OUT\" \"$LOG_ERR\"\n\n        # Switch back to Automatic network location\n        networksetup -switchtolocation Automatic >/dev/null 2>&1 || true\n\n        # Delete the exo network location if it exists\n        networksetup -listlocations 2>/dev/null | grep -q '^exo$' && {\n          networksetup -deletelocation exo >/dev/null 2>&1 || true\n        } || true\n\n        # Re-enable any Thunderbolt Bridge service if it exists\n      
  # We find it dynamically by looking for bridges containing Thunderbolt interfaces\n        find_and_enable_thunderbolt_bridge() {\n          # Get Thunderbolt interface devices from hardware ports\n          tb_devices=$(networksetup -listallhardwareports 2>/dev/null | awk '\n            /^Hardware Port:/ { port = tolower(substr($0, 16)) }\n            /^Device:/ { if (port ~ /thunderbolt/) print substr($0, 9) }\n          ') || true\n          [ -z \"$tb_devices\" ] && return 0\n\n          # For each bridge device, check if it contains Thunderbolt interfaces\n          for bridge in bridge0 bridge1 bridge2; do\n            members=$(ifconfig \"$bridge\" 2>/dev/null | awk '/member:/ {print $2}') || true\n            [ -z \"$members\" ] && continue\n\n            for tb_dev in $tb_devices; do\n              if echo \"$members\" | grep -qx \"$tb_dev\"; then\n                # Find the service name for this bridge device\n                service_name=$(networksetup -listnetworkserviceorder 2>/dev/null | awk -v dev=\"$bridge\" '\n                  /^\\\\([0-9*]/ { gsub(/^\\\\([0-9*]+\\\\) /, \"\"); svc = $0 }\n                  /Device:/ && $0 ~ dev { print svc; exit }\n                ') || true\n                if [ -n \"$service_name\" ]; then\n                  networksetup -setnetworkserviceenabled \"$service_name\" on 2>/dev/null || true\n                  return 0\n                fi\n              fi\n            done\n          done\n          return 0\n        }\n        find_and_enable_thunderbolt_bridge || true\n\n        echo \"EXO network components removed successfully\"\n        \"\"\"\n    }\n\n    /// Direct install without GUI (requires root).\n    /// Returns true on success, false on failure.\n    static func installDirectly() -> Bool {\n        let script = makeInstallerScript()\n        return runShellDirectly(script)\n    }\n\n    /// Direct uninstall without GUI (requires root).\n    /// Returns true on success, false on failure.\n    static func uninstallDirectly() -> Bool {\n        let script = makeUninstallScript()\n        return runShellDirectly(script)\n    }\n\n    /// Run a shell script directly via Process (no AppleScript, requires root).\n    /// Returns true on success, false on failure.\n    private static func runShellDirectly(_ script: String) -> Bool {\n        let process = Process()\n        process.executableURL = URL(fileURLWithPath: \"/bin/bash\")\n        process.arguments = [\"-c\", script]\n\n        let outputPipe = Pipe()\n        let errorPipe = Pipe()\n        process.standardOutput = outputPipe\n        process.standardError = errorPipe\n\n        do {\n            try process.run()\n            process.waitUntilExit()\n\n            let outputData = outputPipe.fileHandleForReading.readDataToEndOfFile()\n            let errorData = errorPipe.fileHandleForReading.readDataToEndOfFile()\n\n            if let output = String(data: outputData, encoding: .utf8), !output.isEmpty {\n                print(output)\n            }\n            if let errorOutput = String(data: errorData, encoding: .utf8), !errorOutput.isEmpty {\n                fputs(errorOutput, stderr)\n            }\n\n            if process.terminationStatus == 0 {\n                logger.info(\"Shell script completed successfully\")\n                return true\n            } else {\n                logger.error(\"Shell script failed with exit code \\(process.terminationStatus)\")\n                return false\n            }\n        } catch {\n            logger.error(\n       
         \"Failed to run shell script: \\(error.localizedDescription, privacy: .public)\")\n            fputs(\"Error: \\(error.localizedDescription)\\n\", stderr)\n            return false\n        }\n    }\n\n    private static func runShellAsAdmin(_ script: String) throws {\n        let escapedScript =\n            script\n            .replacingOccurrences(of: \"\\\\\", with: \"\\\\\\\\\")\n            .replacingOccurrences(of: \"\\\"\", with: \"\\\\\\\"\")\n\n        let appleScriptSource = \"\"\"\n            do shell script \"\\(escapedScript)\" with administrator privileges\n            \"\"\"\n\n        guard let appleScript = NSAppleScript(source: appleScriptSource) else {\n            throw NetworkSetupError.scriptCreationFailed\n        }\n\n        var errorInfo: NSDictionary?\n        appleScript.executeAndReturnError(&errorInfo)\n\n        if let errorInfo {\n            let message = errorInfo[NSAppleScript.errorMessage] as? String ?? \"Unknown error\"\n            throw NetworkSetupError.executionFailed(message)\n        }\n    }\n}\n\nenum NetworkSetupError: LocalizedError {\n    case scriptCreationFailed\n    case executionFailed(String)\n\n    var errorDescription: String? {\n        switch self {\n        case .scriptCreationFailed:\n            return \"Failed to create AppleScript for network setup\"\n        case .executionFailed(let message):\n            return \"Network setup script failed: \\(message)\"\n        }\n    }\n}\n"
  },
  {
    "path": "app/EXO/EXO/Services/NetworkStatusService.swift",
    "content": "import AppKit\nimport Foundation\n\n@MainActor\nfinal class NetworkStatusService: ObservableObject {\n    @Published private(set) var status: NetworkStatus = .empty\n    private var timer: Timer?\n\n    func refresh() async {\n        let fetched = await Task.detached(priority: .background) {\n            NetworkStatusFetcher().fetch()\n        }.value\n        status = fetched\n    }\n\n    func startPolling(interval: TimeInterval = 30) {\n        timer?.invalidate()\n        timer = Timer.scheduledTimer(withTimeInterval: interval, repeats: true) { [weak self] _ in\n            guard let self else { return }\n            Task { await self.refresh() }\n        }\n        if let timer {\n            RunLoop.main.add(timer, forMode: .common)\n        }\n        Task { await refresh() }\n    }\n\n    func stopPolling() {\n        timer?.invalidate()\n        timer = nil\n    }\n}\n\nstruct NetworkStatus: Equatable {\n    let thunderboltBridgeState: ThunderboltState?\n    let bridgeInactive: Bool?\n    let interfaceStatuses: [InterfaceIpStatus]\n    let localRdmaDevices: [String]\n    let localRdmaActivePorts: [RDMAPort]\n\n    static let empty = NetworkStatus(\n        thunderboltBridgeState: nil,\n        bridgeInactive: nil,\n        interfaceStatuses: [],\n        localRdmaDevices: [],\n        localRdmaActivePorts: []\n    )\n}\n\nstruct RDMAPort: Equatable {\n    let device: String\n    let port: String\n    let state: String\n}\n\nstruct InterfaceIpStatus: Equatable {\n    let interfaceName: String\n    let ipAddress: String?\n}\n\nenum ThunderboltState: Equatable {\n    case enabled\n    case disabled\n    case deleted\n}\n\nprivate struct NetworkStatusFetcher {\n    func fetch() -> NetworkStatus {\n        NetworkStatus(\n            thunderboltBridgeState: readThunderboltBridgeState(),\n            bridgeInactive: readBridgeInactive(),\n            interfaceStatuses: readInterfaceStatuses(),\n            localRdmaDevices: readRDMADevices(),\n            localRdmaActivePorts: readRDMAActivePorts()\n        )\n    }\n\n    private func readRDMADevices() -> [String] {\n        let result = runCommand([\"ibv_devices\"])\n        guard result.exitCode == 0 else { return [] }\n        var devices: [String] = []\n        for line in result.output.split(separator: \"\\n\") {\n            let trimmed = line.trimmingCharacters(in: .whitespaces)\n            if trimmed.hasPrefix(\"---\") || trimmed.lowercased().hasPrefix(\"device\")\n                || trimmed.isEmpty\n            {\n                continue\n            }\n            let parts = trimmed.split(separator: \" \", maxSplits: 1)\n            if let deviceName = parts.first {\n                devices.append(String(deviceName))\n            }\n        }\n        return devices\n    }\n\n    private func readRDMAActivePorts() -> [RDMAPort] {\n        let result = runCommand([\"ibv_devinfo\"])\n        guard result.exitCode == 0 else { return [] }\n        var ports: [RDMAPort] = []\n        var currentDevice: String?\n        var currentPort: String?\n\n        for line in result.output.split(separator: \"\\n\") {\n            let trimmed = line.trimmingCharacters(in: .whitespaces)\n            if trimmed.hasPrefix(\"hca_id:\") {\n                currentDevice = trimmed.replacingOccurrences(of: \"hca_id:\", with: \"\")\n                    .trimmingCharacters(in: .whitespaces)\n            } else if trimmed.hasPrefix(\"port:\") {\n                currentPort = trimmed.replacingOccurrences(of: \"port:\", with: \"\")\n  
                  .trimmingCharacters(in: .whitespaces)\n            } else if trimmed.hasPrefix(\"state:\") {\n                let state = trimmed.replacingOccurrences(of: \"state:\", with: \"\").trimmingCharacters(\n                    in: .whitespaces)\n                if let device = currentDevice, let port = currentPort {\n                    if state.lowercased().contains(\"active\") {\n                        ports.append(RDMAPort(device: device, port: port, state: state))\n                    }\n                }\n            }\n        }\n        return ports\n    }\n\n    private func readThunderboltBridgeState() -> ThunderboltState? {\n        // Dynamically find the Thunderbolt Bridge service (don't assume the name)\n        guard let serviceName = ThunderboltBridgeDetector.findThunderboltBridgeServiceName() else {\n            // No bridge containing Thunderbolt interfaces exists\n            return .deleted\n        }\n\n        guard let isEnabled = ThunderboltBridgeDetector.isServiceEnabled(serviceName: serviceName)\n        else {\n            return nil\n        }\n\n        return isEnabled ? .enabled : .disabled\n    }\n\n    private func readBridgeInactive() -> Bool? {\n        let result = runCommand([\"ifconfig\", \"bridge0\"])\n        guard result.exitCode == 0 else { return nil }\n        guard\n            let statusLine = result.output\n                .components(separatedBy: .newlines)\n                .first(where: { $0.contains(\"status:\") })?\n                .lowercased()\n        else {\n            return nil\n        }\n        if statusLine.contains(\"inactive\") {\n            return true\n        }\n        if statusLine.contains(\"active\") {\n            return false\n        }\n        return nil\n    }\n\n    private func readInterfaceStatuses() -> [InterfaceIpStatus] {\n        (0...7).map { \"en\\($0)\" }.map(readInterfaceStatus)\n    }\n\n    private func readInterfaceStatus(for interface: String) -> InterfaceIpStatus {\n        let result = runCommand([\"ifconfig\", interface])\n        guard result.exitCode == 0 else {\n            return InterfaceIpStatus(\n                interfaceName: interface,\n                ipAddress: nil\n            )\n        }\n\n        let output = result.output\n        let ip = firstInet(from: output)\n\n        return InterfaceIpStatus(\n            interfaceName: interface,\n            ipAddress: ip\n        )\n    }\n\n    private func firstInet(from ifconfigOutput: String) -> String? 
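    // Returns the first non-loopback IPv4 address parsed from the inet lines of ifconfig output.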
{\n        for line in ifconfigOutput.split(separator: \"\\n\") {\n            let trimmed = line.trimmingCharacters(in: .whitespaces)\n            guard trimmed.hasPrefix(\"inet \") else { continue }\n            let parts = trimmed.split(separator: \" \")\n            if parts.count >= 2 {\n                let candidate = String(parts[1])\n                if candidate != \"127.0.0.1\" {\n                    return candidate\n                }\n            }\n        }\n        return nil\n    }\n\n    private struct CommandResult {\n        let exitCode: Int32\n        let output: String\n        let error: String\n    }\n\n    private func runCommand(_ arguments: [String]) -> CommandResult {\n        let process = Process()\n        process.executableURL = URL(fileURLWithPath: \"/usr/bin/env\")\n        process.arguments = arguments\n\n        let stdout = Pipe()\n        let stderr = Pipe()\n        process.standardOutput = stdout\n        process.standardError = stderr\n\n        do {\n            try process.run()\n        } catch {\n            return CommandResult(exitCode: -1, output: \"\", error: error.localizedDescription)\n        }\n        process.waitUntilExit()\n\n        let outputData = stdout.fileHandleForReading.readDataToEndOfFile()\n        let errorData = stderr.fileHandleForReading.readDataToEndOfFile()\n\n        return CommandResult(\n            exitCode: process.terminationStatus,\n            output: String(decoding: outputData, as: UTF8.self),\n            error: String(decoding: errorData, as: UTF8.self)\n        )\n    }\n}\n"
  },
  {
    "path": "app/EXO/EXO/Services/ThunderboltBridgeDetector.swift",
    "content": "import Foundation\nimport os.log\n\n/// Utility for dynamically detecting Thunderbolt Bridge network services.\n/// This mirrors the Python logic in info_gatherer.py - we never assume the service\n/// is named \"Thunderbolt Bridge\", instead we find bridges containing Thunderbolt interfaces.\nenum ThunderboltBridgeDetector {\n    private static let logger = Logger(\n        subsystem: \"io.exo.EXO\", category: \"ThunderboltBridgeDetector\")\n\n    struct CommandResult {\n        let exitCode: Int32\n        let output: String\n        let error: String\n    }\n\n    /// Find the network service name of a bridge containing Thunderbolt interfaces.\n    /// Returns nil if no such bridge exists.\n    static func findThunderboltBridgeServiceName() -> String? {\n        // 1. Get all Thunderbolt interface devices (e.g., en2, en3)\n        guard let thunderboltDevices = getThunderboltDevices(), !thunderboltDevices.isEmpty else {\n            logger.debug(\"No Thunderbolt devices found\")\n            return nil\n        }\n        logger.debug(\"Found Thunderbolt devices: \\(thunderboltDevices.joined(separator: \", \"))\")\n\n        // 2. Get bridge services from network service order\n        guard let bridgeServices = getBridgeServices(), !bridgeServices.isEmpty else {\n            logger.debug(\"No bridge services found\")\n            return nil\n        }\n        logger.debug(\"Found bridge services: \\(bridgeServices.keys.joined(separator: \", \"))\")\n\n        // 3. Find a bridge that contains Thunderbolt interfaces\n        for (bridgeDevice, serviceName) in bridgeServices {\n            let members = getBridgeMembers(bridgeDevice: bridgeDevice)\n            logger.debug(\n                \"Bridge \\(bridgeDevice) (\\(serviceName)) has members: \\(members.joined(separator: \", \"))\"\n            )\n\n            // Check if any Thunderbolt device is a member of this bridge\n            if !members.isDisjoint(with: thunderboltDevices) {\n                logger.info(\n                    \"Found Thunderbolt Bridge service: '\\(serviceName)' (device: \\(bridgeDevice))\")\n                return serviceName\n            }\n        }\n\n        logger.debug(\"No bridge found containing Thunderbolt interfaces\")\n        return nil\n    }\n\n    /// Get Thunderbolt interface device names (e.g., en2, en3) from hardware ports.\n    private static func getThunderboltDevices() -> Set<String>? 
{\n        let result = runCommand([\"networksetup\", \"-listallhardwareports\"])\n        guard result.exitCode == 0 else {\n            logger.warning(\"networksetup -listallhardwareports failed: \\(result.error)\")\n            return nil\n        }\n\n        var thunderboltDevices: Set<String> = []\n        var currentPort: String?\n\n        for line in result.output.components(separatedBy: .newlines) {\n            let trimmed = line.trimmingCharacters(in: .whitespaces)\n            if trimmed.hasPrefix(\"Hardware Port:\") {\n                currentPort = String(trimmed.dropFirst(\"Hardware Port:\".count)).trimmingCharacters(\n                    in: .whitespaces)\n            } else if trimmed.hasPrefix(\"Device:\"), let port = currentPort {\n                let device = String(trimmed.dropFirst(\"Device:\".count)).trimmingCharacters(\n                    in: .whitespaces)\n                if port.lowercased().contains(\"thunderbolt\") {\n                    thunderboltDevices.insert(device)\n                }\n                currentPort = nil\n            }\n        }\n\n        return thunderboltDevices\n    }\n\n    /// Get mapping of bridge device -> service name from network service order.\n    private static func getBridgeServices() -> [String: String]? {\n        let result = runCommand([\"networksetup\", \"-listnetworkserviceorder\"])\n        guard result.exitCode == 0 else {\n            logger.warning(\"networksetup -listnetworkserviceorder failed: \\(result.error)\")\n            return nil\n        }\n\n        // Parse service order to find bridge devices and their service names\n        // Format: \"(1) Service Name\\n(Hardware Port: ..., Device: bridge0)\\n\"\n        var bridgeServices: [String: String] = [:]\n        var currentService: String?\n\n        for line in result.output.components(separatedBy: .newlines) {\n            let trimmed = line.trimmingCharacters(in: .whitespaces)\n\n            // Match \"(N) Service Name\" or \"(*) Service Name\" (disabled)\n            // but NOT \"(Hardware Port: ...)\" lines\n            if trimmed.hasPrefix(\"(\"), trimmed.contains(\")\"),\n                !trimmed.hasPrefix(\"(Hardware Port:\")\n            {\n                if let parenEnd = trimmed.firstIndex(of: \")\") {\n                    let afterParen = trimmed.index(after: parenEnd)\n                    if afterParen < trimmed.endIndex {\n                        currentService =\n                            String(trimmed[afterParen...])\n                            .trimmingCharacters(in: .whitespaces)\n                    }\n                }\n            }\n            // Match \"(Hardware Port: ..., Device: bridgeX)\"\n            else if let service = currentService, trimmed.contains(\"Device: bridge\") {\n                // Extract device name from \"..., Device: bridge0)\"\n                if let deviceRange = trimmed.range(of: \"Device: \") {\n                    let afterDevice = trimmed[deviceRange.upperBound...]\n                    if let parenIndex = afterDevice.firstIndex(of: \")\") {\n                        let device = String(afterDevice[..<parenIndex])\n                        bridgeServices[device] = service\n                    }\n                }\n            }\n        }\n\n        return bridgeServices\n    }\n\n    /// Get member interfaces of a bridge device via ifconfig.\n    private static func getBridgeMembers(bridgeDevice: String) -> Set<String> {\n        let result = runCommand([\"ifconfig\", bridgeDevice])\n        guard 
result.exitCode == 0 else {\n            logger.debug(\"ifconfig \\(bridgeDevice) failed\")\n            return []\n        }\n\n        var members: Set<String> = []\n        for line in result.output.components(separatedBy: .newlines) {\n            let trimmed = line.trimmingCharacters(in: .whitespaces)\n            if trimmed.hasPrefix(\"member:\") {\n                let parts = trimmed.split(separator: \" \")\n                if parts.count > 1 {\n                    members.insert(String(parts[1]))\n                }\n            }\n        }\n\n        return members\n    }\n\n    /// Check if a network service is enabled.\n    static func isServiceEnabled(serviceName: String) -> Bool? {\n        let result = runCommand([\"networksetup\", \"-getnetworkserviceenabled\", serviceName])\n        guard result.exitCode == 0 else {\n            logger.warning(\"Failed to check if '\\(serviceName)' is enabled: \\(result.error)\")\n            return nil\n        }\n\n        let output = result.output.lowercased().trimmingCharacters(in: .whitespacesAndNewlines)\n        if output.contains(\"enabled\") {\n            return true\n        }\n        if output.contains(\"disabled\") {\n            return false\n        }\n        return nil\n    }\n\n    private static func runCommand(_ arguments: [String]) -> CommandResult {\n        let process = Process()\n        process.executableURL = URL(fileURLWithPath: \"/usr/bin/env\")\n        process.arguments = arguments\n\n        let stdout = Pipe()\n        let stderr = Pipe()\n        process.standardOutput = stdout\n        process.standardError = stderr\n\n        do {\n            try process.run()\n        } catch {\n            return CommandResult(exitCode: -1, output: \"\", error: error.localizedDescription)\n        }\n        process.waitUntilExit()\n\n        let outputData = stdout.fileHandleForReading.readDataToEndOfFile()\n        let errorData = stderr.fileHandleForReading.readDataToEndOfFile()\n\n        return CommandResult(\n            exitCode: process.terminationStatus,\n            output: String(decoding: outputData, as: UTF8.self),\n            error: String(decoding: errorData, as: UTF8.self)\n        )\n    }\n}\n"
  },
  {
    "path": "app/EXO/EXO/Services/ThunderboltBridgeService.swift",
    "content": "import AppKit\nimport Combine\nimport Foundation\nimport Security\nimport SystemConfiguration\nimport os.log\n\n@MainActor\nfinal class ThunderboltBridgeService: ObservableObject {\n    private static let logger = Logger(subsystem: \"io.exo.EXO\", category: \"ThunderboltBridge\")\n\n    @Published private(set) var detectedCycle: [String]?\n    @Published private(set) var hasPromptedForCurrentCycle = false\n    @Published private(set) var lastError: String?\n\n    private weak var clusterStateService: ClusterStateService?\n    private var cancellables = Set<AnyCancellable>()\n    private var previousCycleSignature: String?\n\n    init(clusterStateService: ClusterStateService) {\n        self.clusterStateService = clusterStateService\n        setupObserver()\n    }\n\n    private func setupObserver() {\n        guard let service = clusterStateService else { return }\n\n        service.$latestSnapshot\n            .compactMap { $0 }\n            .sink { [weak self] snapshot in\n                self?.checkForCycles(snapshot: snapshot)\n            }\n            .store(in: &cancellables)\n    }\n\n    private func checkForCycles(snapshot: ClusterState) {\n        let cycles = snapshot.thunderboltBridgeCycles\n\n        // Only consider cycles with more than 2 nodes\n        guard let firstCycle = cycles.first, firstCycle.count > 2 else {\n            // No problematic cycles detected, reset state\n            if detectedCycle != nil {\n                detectedCycle = nil\n                previousCycleSignature = nil\n                hasPromptedForCurrentCycle = false\n            }\n            return\n        }\n\n        // Create a signature for this cycle to detect if it changed\n        let cycleSignature = firstCycle.sorted().joined(separator: \",\")\n\n        // If this is a new/different cycle, reset the prompt state\n        if cycleSignature != previousCycleSignature {\n            previousCycleSignature = cycleSignature\n            hasPromptedForCurrentCycle = false\n        }\n\n        detectedCycle = firstCycle\n\n        // Only prompt once per cycle\n        if !hasPromptedForCurrentCycle {\n            showDisableBridgePrompt(nodeIds: firstCycle)\n        }\n    }\n\n    private func showDisableBridgePrompt(nodeIds: [String]) {\n        hasPromptedForCurrentCycle = true\n\n        // Get friendly names for the nodes if available\n        let nodeNames = nodeIds.map { nodeId -> String in\n            if let snapshot = clusterStateService?.latestSnapshot,\n                let profile = snapshot.nodeProfiles[nodeId],\n                let friendlyName = profile.friendlyName, !friendlyName.isEmpty\n            {\n                return friendlyName\n            }\n            return String(nodeId.prefix(8))  // Use first 8 chars of node ID as fallback\n        }\n        let machineNames = nodeNames.joined(separator: \", \")\n\n        let alert = NSAlert()\n        alert.messageText = \"Thunderbolt Bridge Loop Detected\"\n        alert.informativeText = \"\"\"\n            A Thunderbolt Bridge loop has been detected between \\(nodeNames.count) machines: \\(machineNames).\n\n            This can cause network packet storms and connectivity issues. 
Would you like to disable Thunderbolt Bridge on this machine to break the loop?\n            \"\"\"\n        alert.alertStyle = .warning\n        alert.addButton(withTitle: \"Disable Bridge\")\n        alert.addButton(withTitle: \"Not Now\")\n\n        let response = alert.runModal()\n\n        if response == .alertFirstButtonReturn {\n            Task {\n                await disableThunderboltBridge()\n            }\n        }\n    }\n\n    func disableThunderboltBridge() async {\n        Self.logger.info(\"Attempting to disable Thunderbolt Bridge via SCPreferences\")\n        lastError = nil\n\n        do {\n            try await disableThunderboltBridgeWithSCPreferences()\n            Self.logger.info(\"Successfully disabled Thunderbolt Bridge\")\n        } catch {\n            Self.logger.error(\n                \"Failed to disable Thunderbolt Bridge: \\(error.localizedDescription, privacy: .public)\"\n            )\n            lastError = error.localizedDescription\n            showErrorAlert(message: error.localizedDescription)\n        }\n    }\n\n    private func disableThunderboltBridgeWithSCPreferences() async throws {\n        // 1. Create authorization reference\n        var authRef: AuthorizationRef?\n        var status = AuthorizationCreate(nil, nil, [], &authRef)\n        guard status == errAuthorizationSuccess, let authRef = authRef else {\n            throw ThunderboltBridgeError.authorizationFailed\n        }\n\n        defer { AuthorizationFree(authRef, [.destroyRights]) }\n\n        // 2. Request specific network configuration rights\n        let rightName = \"system.services.systemconfiguration.network\"\n        status = rightName.withCString { nameCString in\n            var item = AuthorizationItem(\n                name: nameCString,\n                valueLength: 0,\n                value: nil,\n                flags: 0\n            )\n            return withUnsafeMutablePointer(to: &item) { itemPointer in\n                var rights = AuthorizationRights(count: 1, items: itemPointer)\n                return AuthorizationCopyRights(\n                    authRef,\n                    &rights,\n                    nil,\n                    [.extendRights, .interactionAllowed],\n                    nil\n                )\n            }\n        }\n        guard status == errAuthorizationSuccess else {\n            if status == errAuthorizationCanceled {\n                throw ThunderboltBridgeError.authorizationCanceled\n            }\n            throw ThunderboltBridgeError.authorizationDenied\n        }\n\n        // 3. Create SCPreferences with authorization\n        guard\n            let prefs = SCPreferencesCreateWithAuthorization(\n                kCFAllocatorDefault,\n                \"EXO\" as CFString,\n                nil,\n                authRef\n            )\n        else {\n            throw ThunderboltBridgeError.preferencesCreationFailed\n        }\n\n        // 4. Lock, modify, commit\n        guard SCPreferencesLock(prefs, true) else {\n            throw ThunderboltBridgeError.lockFailed\n        }\n\n        defer {\n            SCPreferencesUnlock(prefs)\n        }\n\n        // 5. Find the Thunderbolt Bridge service dynamically (don't assume the name)\n        guard let targetServiceName = ThunderboltBridgeDetector.findThunderboltBridgeServiceName()\n        else {\n            throw ThunderboltBridgeError.serviceNotFound\n        }\n\n        guard let allServices = SCNetworkServiceCopyAll(prefs) as? 
[SCNetworkService] else {\n            throw ThunderboltBridgeError.servicesNotFound\n        }\n\n        var found = false\n        for service in allServices {\n            if let name = SCNetworkServiceGetName(service) as String?,\n                name == targetServiceName\n            {\n                guard SCNetworkServiceSetEnabled(service, false) else {\n                    throw ThunderboltBridgeError.disableFailed\n                }\n                found = true\n                Self.logger.info(\n                    \"Found and disabled Thunderbolt Bridge service: '\\(targetServiceName)'\")\n                break\n            }\n        }\n\n        if !found {\n            throw ThunderboltBridgeError.serviceNotFound\n        }\n\n        // 6. Commit and apply\n        guard SCPreferencesCommitChanges(prefs) else {\n            throw ThunderboltBridgeError.commitFailed\n        }\n\n        guard SCPreferencesApplyChanges(prefs) else {\n            throw ThunderboltBridgeError.applyFailed\n        }\n    }\n\n    private func showErrorAlert(message: String) {\n        let alert = NSAlert()\n        alert.messageText = \"Failed to Disable Thunderbolt Bridge\"\n        alert.informativeText = message\n        alert.alertStyle = .critical\n        alert.addButton(withTitle: \"OK\")\n        alert.runModal()\n    }\n}\n\nenum ThunderboltBridgeError: LocalizedError {\n    case authorizationFailed\n    case authorizationCanceled\n    case authorizationDenied\n    case preferencesCreationFailed\n    case lockFailed\n    case servicesNotFound\n    case serviceNotFound\n    case disableFailed\n    case commitFailed\n    case applyFailed\n\n    var errorDescription: String? {\n        switch self {\n        case .authorizationFailed:\n            return \"Failed to create authorization\"\n        case .authorizationCanceled:\n            return \"Authorization was canceled by user\"\n        case .authorizationDenied:\n            return \"Authorization was denied\"\n        case .preferencesCreationFailed:\n            return \"Failed to access network preferences\"\n        case .lockFailed:\n            return \"Failed to lock network preferences for modification\"\n        case .servicesNotFound:\n            return \"Could not retrieve network services\"\n        case .serviceNotFound:\n            return \"Thunderbolt Bridge service not found\"\n        case .disableFailed:\n            return \"Failed to disable Thunderbolt Bridge service\"\n        case .commitFailed:\n            return \"Failed to save network configuration changes\"\n        case .applyFailed:\n            return \"Failed to apply network configuration changes\"\n        }\n    }\n}\n"
  },
  {
    "path": "app/EXO/EXO/ViewModels/InstanceViewModel.swift",
    "content": "import Foundation\n\nstruct DownloadProgressViewModel: Equatable {\n    let downloadedBytes: Int64\n    let totalBytes: Int64\n    let speedBytesPerSecond: Double\n    let etaSeconds: Double?\n    let completedFiles: Int\n    let totalFiles: Int\n\n    var fractionCompleted: Double {\n        guard totalBytes > 0 else { return 0 }\n        return Double(downloadedBytes) / Double(totalBytes)\n    }\n\n    var percentCompleted: Double {\n        fractionCompleted * 100\n    }\n\n    var formattedProgress: String {\n        let downloaded = formatBytes(downloadedBytes)\n        let total = formatBytes(totalBytes)\n        let percent = String(format: \"%.1f\", percentCompleted)\n        return \"\\(downloaded)/\\(total) (\\(percent)%)\"\n    }\n\n    var formattedSpeed: String {\n        \"\\(formatBytes(Int64(speedBytesPerSecond)))/s\"\n    }\n\n    var formattedETA: String? {\n        guard let eta = etaSeconds, eta > 0 else { return nil }\n        let minutes = Int(eta) / 60\n        let seconds = Int(eta) % 60\n        if minutes > 0 {\n            return \"ETA \\(minutes)m \\(seconds)s\"\n        }\n        return \"ETA \\(seconds)s\"\n    }\n\n    private func formatBytes(_ bytes: Int64) -> String {\n        let gb = Double(bytes) / 1_073_741_824.0\n        let mb = Double(bytes) / 1_048_576.0\n        if gb >= 1.0 {\n            return String(format: \"%.2f GB\", gb)\n        }\n        return String(format: \"%.0f MB\", mb)\n    }\n}\n\nstruct InstanceViewModel: Identifiable, Equatable {\n    enum State {\n        case downloading\n        case warmingUp\n        case running\n        case ready\n        case waiting\n        case failed\n        case idle\n        case preparing\n\n        var label: String {\n            switch self {\n            case .downloading: return \"Downloading\"\n            case .warmingUp: return \"Warming Up\"\n            case .running: return \"Running\"\n            case .ready: return \"Ready\"\n            case .waiting: return \"Waiting\"\n            case .failed: return \"Failed\"\n            case .idle: return \"Idle\"\n            case .preparing: return \"Preparing\"\n            }\n        }\n    }\n\n    let id: String\n    let modelName: String\n    let sharding: String?\n    let nodeNames: [String]\n    let state: State\n    let chatTasks: [InstanceTaskViewModel]\n    let downloadProgress: DownloadProgressViewModel?\n\n    var nodeSummary: String {\n        guard !nodeNames.isEmpty else { return \"0 nodes\" }\n        if nodeNames.count == 1 {\n            return nodeNames[0]\n        }\n        if nodeNames.count == 2 {\n            return nodeNames.joined(separator: \", \")\n        }\n        let others = nodeNames.count - 1\n        return \"\\(nodeNames.first ?? \"\") +\\(others)\"\n    }\n}\n\nextension ClusterState {\n    func instanceViewModels() -> [InstanceViewModel] {\n        let chatTasksByInstance = Dictionary(\n            grouping: tasks.values.filter { $0.kind == .chatCompletion && $0.instanceId != nil },\n            by: { $0.instanceId! }\n        )\n\n        return instances.map { entry in\n            let instance = entry.value\n            let modelName = instance.shardAssignments.modelId\n            let nodeToRunner = instance.shardAssignments.nodeToRunner\n            let nodeIds = Array(nodeToRunner.keys)\n            let runnerIds = Array(nodeToRunner.values)\n            let nodeNames = nodeIds.compactMap {\n                nodeProfiles[$0]?.friendlyName ?? nodeProfiles[$0]?.modelId ?? 
$0\n            }\n            let statuses = runnerIds.compactMap { runners[$0]?.status.lowercased() }\n            let downloadProgress = aggregateDownloadProgress(for: nodeIds)\n            let state = InstanceViewModel.State(\n                statuses: statuses, hasActiveDownload: downloadProgress != nil)\n            let chatTasks = (chatTasksByInstance[entry.key] ?? [])\n                .sorted(by: { $0.sortPriority < $1.sortPriority })\n                .map { InstanceTaskViewModel(task: $0) }\n            return InstanceViewModel(\n                id: entry.key,\n                modelName: modelName,\n                sharding: InstanceViewModel.friendlyShardingName(for: instance.sharding),\n                nodeNames: nodeNames,\n                state: state,\n                chatTasks: chatTasks,\n                downloadProgress: downloadProgress\n            )\n        }\n        .sorted { $0.modelName < $1.modelName }\n    }\n\n    private func aggregateDownloadProgress(for nodeIds: [String]) -> DownloadProgressViewModel? {\n        var totalDownloaded: Int64 = 0\n        var totalSize: Int64 = 0\n        var totalSpeed: Double = 0\n        var maxEtaMs: Int64 = 0\n        var totalCompletedFiles = 0\n        var totalFileCount = 0\n        var hasActiveDownload = false\n\n        for nodeId in nodeIds {\n            guard let nodeDownloads = downloads[nodeId] else { continue }\n            for download in nodeDownloads {\n                guard let progress = download.progress else { continue }\n                hasActiveDownload = true\n                totalDownloaded += progress.downloadedBytes.inBytes\n                totalSize += progress.totalBytes.inBytes\n                totalSpeed += progress.speed ?? 0\n                if let eta = progress.etaMs {\n                    maxEtaMs = max(maxEtaMs, eta)\n                }\n                totalCompletedFiles += progress.completedFiles ?? 0\n                totalFileCount += progress.totalFiles ?? 0\n            }\n        }\n\n        guard hasActiveDownload else { return nil }\n\n        return DownloadProgressViewModel(\n            downloadedBytes: totalDownloaded,\n            totalBytes: totalSize,\n            speedBytesPerSecond: totalSpeed,\n            etaSeconds: maxEtaMs > 0 ? Double(maxEtaMs) / 1000.0 : nil,\n            completedFiles: totalCompletedFiles,\n            totalFiles: totalFileCount\n        )\n    }\n}\n\nextension InstanceViewModel.State {\n    fileprivate init(statuses: [String], hasActiveDownload: Bool = false) {\n        if statuses.contains(where: { $0.contains(\"failed\") }) {\n            self = .failed\n        } else if hasActiveDownload || statuses.contains(where: { $0.contains(\"downloading\") }) {\n            self = .downloading\n        } else if statuses.contains(where: { $0.contains(\"warming\") }) {\n            self = .warmingUp\n        } else if statuses.contains(where: { $0.contains(\"running\") }) {\n            self = .running\n        } else if statuses.contains(where: { $0.contains(\"ready\") || $0.contains(\"loaded\") }) {\n            self = .ready\n        } else if statuses.contains(where: { $0.contains(\"waiting\") }) {\n            self = .waiting\n        } else if statuses.isEmpty {\n            self = .idle\n        } else {\n            self = .preparing\n        }\n    }\n}\n\nextension InstanceViewModel {\n    static func friendlyShardingName(for raw: String?) -> String? 
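    // Maps internal sharding identifiers to human-readable labels; unknown identifiers are returned unchanged.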
{\n        guard let raw else { return nil }\n        switch raw.lowercased() {\n        case \"mlxringinstance\", \"mlxring\":\n            return \"MLX Ring\"\n        case \"mlxibvinstance\", \"mlxibv\":\n            return \"MLX RDMA\"\n        default:\n            return raw\n        }\n    }\n}\n\nstruct InstanceTaskViewModel: Identifiable, Equatable {\n    enum Kind {\n        case chatCompletion\n    }\n\n    let id: String\n    let kind: Kind\n    let status: TaskStatus\n    let modelName: String?\n    let promptPreview: String?\n    let errorMessage: String?\n    let subtitle: String?\n    let parameters: TextGenerationTaskParameters?\n\n    var title: String {\n        switch kind {\n        case .chatCompletion:\n            return \"Chat Completion\"\n        }\n    }\n\n    var detailText: String? {\n        if let errorMessage, !errorMessage.isEmpty {\n            return errorMessage\n        }\n        return promptPreview\n    }\n\n}\n\nextension InstanceTaskViewModel {\n    init(task: ClusterTask) {\n        self.id = task.id\n        self.kind = .chatCompletion\n        self.status = task.status\n        self.modelName = task.modelName\n        self.promptPreview = task.promptPreview\n        self.errorMessage = task.errorMessage\n        self.subtitle = task.modelName\n        self.parameters = task.parameters\n    }\n}\n"
  },
  {
    "path": "app/EXO/EXO/ViewModels/NodeViewModel.swift",
    "content": "import Foundation\n\nstruct NodeViewModel: Identifiable, Equatable {\n    let id: String\n    let friendlyName: String\n    let model: String\n    let usedRamGB: Double\n    let totalRamGB: Double\n    let gpuUsagePercent: Double\n    let cpuUsagePercent: Double\n    let temperatureCelsius: Double\n    let systenPowerWatts: Double\n\n    var memoryProgress: Double {\n        guard totalRamGB > 0 else { return 0 }\n        return min(max(usedRamGB / totalRamGB, 0), 1)\n    }\n\n    var memoryLabel: String {\n        String(format: \"%.1f / %.1f GB\", usedRamGB, totalRamGB)\n    }\n\n    var temperatureLabel: String {\n        String(format: \"%.0f°C\", temperatureCelsius)\n    }\n\n    var powerLabel: String {\n        systenPowerWatts > 0 ? String(format: \"%.0fW\", systenPowerWatts) : \"—\"\n    }\n\n    var cpuUsageLabel: String {\n        String(format: \"%.0f%%\", cpuUsagePercent)\n    }\n\n    var gpuUsageLabel: String {\n        String(format: \"%.0f%%\", gpuUsagePercent)\n    }\n\n    var deviceIconName: String {\n        let lower = model.lowercased()\n        if lower.contains(\"studio\") {\n            return \"macstudio\"\n        }\n        if lower.contains(\"mini\") {\n            return \"macmini\"\n        }\n        return \"macbook\"\n    }\n}\n\nextension ClusterState {\n    func nodeViewModels() -> [NodeViewModel] {\n        nodeProfiles.map { entry in\n            let profile = entry.value\n            let friendly = profile.friendlyName ?? profile.modelId ?? entry.key\n            let model = profile.modelId ?? \"Unknown\"\n            let totalBytes = Double(profile.memory?.ramTotal?.inBytes ?? 0)\n            let availableBytes = Double(profile.memory?.ramAvailable?.inBytes ?? 0)\n            let usedBytes = max(totalBytes - availableBytes, 0)\n            return NodeViewModel(\n                id: entry.key,\n                friendlyName: friendly,\n                model: model,\n                usedRamGB: usedBytes / 1_073_741_824.0,\n                totalRamGB: totalBytes / 1_073_741_824.0,\n                gpuUsagePercent: (profile.system?.gpuUsage ?? 0) * 100,\n                cpuUsagePercent: (profile.system?.pcpuUsage ?? 0) * 100,\n                temperatureCelsius: profile.system?.temp ?? 0,\n                systenPowerWatts: profile.system?.sysPower ?? 0\n            )\n        }\n        .sorted { $0.friendlyName < $1.friendlyName }\n    }\n}\n\nstruct TopologyEdgeViewModel: Hashable {\n    let sourceId: String\n    let targetId: String\n}\n\nstruct TopologyViewModel {\n    let nodes: [NodeViewModel]\n    let edges: [TopologyEdgeViewModel]\n    let currentNodeId: String?\n}\n\nextension ClusterState {\n    func topologyViewModel(localNodeId: String?) -> TopologyViewModel? {\n        let topologyNodeIds = Set(topology?.nodes ?? 
[])\n        let allNodes = nodeViewModels().filter {\n            topologyNodeIds.isEmpty || topologyNodeIds.contains($0.id)\n        }\n        guard !allNodes.isEmpty else { return nil }\n\n        let nodesById = Dictionary(uniqueKeysWithValues: allNodes.map { ($0.id, $0) })\n        var orderedNodes: [NodeViewModel] = []\n        if let topologyNodes = topology?.nodes {\n            for nodeId in topologyNodes {\n                if let viewModel = nodesById[nodeId] {\n                    orderedNodes.append(viewModel)\n                }\n            }\n            let seenIds = Set(orderedNodes.map(\\.id))\n            let remaining = allNodes.filter { !seenIds.contains($0.id) }\n            orderedNodes.append(contentsOf: remaining)\n        } else {\n            orderedNodes = allNodes\n        }\n\n        // Rotate so the local node (from /node_id API) is first\n        if let localId = localNodeId,\n            let index = orderedNodes.firstIndex(where: { $0.id == localId })\n        {\n            orderedNodes = Array(orderedNodes[index...]) + Array(orderedNodes[..<index])\n        }\n\n        let nodeIds = Set(orderedNodes.map(\\.id))\n        let edgesArray: [TopologyEdgeViewModel] =\n            topology?.connections.compactMap { connection in\n                guard nodeIds.contains(connection.localNodeId),\n                    nodeIds.contains(connection.sendBackNodeId)\n                else { return nil }\n                return TopologyEdgeViewModel(\n                    sourceId: connection.localNodeId, targetId: connection.sendBackNodeId)\n            } ?? []\n        let edges = Set(edgesArray)\n\n        return TopologyViewModel(\n            nodes: orderedNodes, edges: Array(edges), currentNodeId: localNodeId)\n    }\n}\n"
  },
  {
    "path": "app/EXO/EXO/Views/FirstLaunchPopout.swift",
    "content": "import AppKit\nimport SwiftUI\n\n/// A popover callout anchored to the menu bar icon on every launch,\n/// pointing the user to the web dashboard with an arrow connecting to the icon.\n@MainActor\nfinal class FirstLaunchPopout {\n    private var popover: NSPopover?\n    private var countdownTask: Task<Void, Never>?\n    private static let dashboardURL = \"http://localhost:52415/\"\n\n    /// Called when the user completes onboarding (clicks Open Dashboard or dismisses).\n    var onComplete: (() -> Void)?\n\n    func show() {\n        guard popover == nil else { return }\n\n        // The status bar button may not exist yet on first launch; retry generously.\n        showWithRetry(attemptsRemaining: 15)\n    }\n\n    private func showWithRetry(attemptsRemaining: Int) {\n        guard attemptsRemaining > 0 else {\n            // Exhausted retries — fall back to just opening the dashboard directly.\n            openDashboard()\n            onComplete?()\n            return\n        }\n\n        guard let button = Self.findStatusItemButton() else {\n            DispatchQueue.main.asyncAfter(deadline: .now() + 1.0) { [weak self] in\n                self?.showWithRetry(attemptsRemaining: attemptsRemaining - 1)\n            }\n            return\n        }\n\n        let pop = NSPopover()\n        pop.behavior = .applicationDefined\n        pop.animates = true\n        pop.contentSize = NSSize(width: 280, height: 120)\n        pop.contentViewController = NSHostingController(\n            rootView: WelcomeCalloutView(\n                countdownDuration: 10,\n                onDismiss: { [weak self] in\n                    self?.onComplete?()\n                    self?.dismiss()\n                },\n                onOpen: { [weak self] in\n                    self?.openDashboard()\n                    self?.onComplete?()\n                    self?.dismiss()\n                }\n            )\n        )\n\n        self.popover = pop\n        pop.show(relativeTo: button.bounds, of: button, preferredEdge: .minY)\n\n        // Auto-open dashboard after 10s then dismiss\n        countdownTask = Task {\n            try? await Task.sleep(nanoseconds: 10_000_000_000)\n            if !Task.isCancelled {\n                openDashboard()\n                onComplete?()\n                dismiss()\n            }\n        }\n    }\n\n    func dismiss() {\n        countdownTask?.cancel()\n        countdownTask = nil\n        guard let pop = popover else { return }\n        popover = nil\n        pop.performClose(nil)\n    }\n\n    private func openDashboard() {\n        guard let url = URL(string: Self.dashboardURL) else { return }\n        NSWorkspace.shared.open(url)\n    }\n\n    /// Finds the NSStatusBarButton created by SwiftUI's MenuBarExtra.\n    /// Walks the view hierarchy to find the actual button rather than the content view.\n    private static func findStatusItemButton() -> NSView? 
{\n        for window in NSApp.windows {\n            let className = NSStringFromClass(type(of: window))\n            // Match NSStatusBarWindow or any internal SwiftUI status bar window\n            guard className.contains(\"StatusBar\") || className.contains(\"MenuBarExtra\") else {\n                continue\n            }\n            if let content = window.contentView {\n                if let button = findButton(in: content) {\n                    return button\n                }\n                // Fall back to the content view itself if it has a non-zero frame\n                if content.frame.width > 0 {\n                    return content\n                }\n            }\n        }\n        return nil\n    }\n\n    /// Recursively searches the view hierarchy for an NSStatusBarButton.\n    private static func findButton(in view: NSView) -> NSView? {\n        let className = NSStringFromClass(type(of: view))\n        if className.contains(\"StatusBarButton\") || className.contains(\"StatusItem\") {\n            return view\n        }\n        for subview in view.subviews {\n            if let found = findButton(in: subview) {\n                return found\n            }\n        }\n        return nil\n    }\n}\n\n/// Minimal welcome callout — friendly pointer, not a wall of text.\n/// Rendered inside the NSPopover which provides its own chrome and arrow.\nprivate struct WelcomeCalloutView: View {\n    let countdownDuration: Int\n    let onDismiss: () -> Void\n    let onOpen: () -> Void\n    @State private var countdown: Int\n    @State private var timerTask: Task<Void, Never>?\n\n    init(countdownDuration: Int, onDismiss: @escaping () -> Void, onOpen: @escaping () -> Void) {\n        self.countdownDuration = countdownDuration\n        self.onDismiss = onDismiss\n        self.onOpen = onOpen\n        self._countdown = State(initialValue: countdownDuration)\n    }\n\n    var body: some View {\n        VStack(alignment: .leading, spacing: 10) {\n            HStack(alignment: .top) {\n                Text(\"EXO is running\")\n                    .font(.system(.headline, design: .rounded))\n                    .fontWeight(.semibold)\n                    .foregroundColor(.primary)\n                Spacer()\n                Button {\n                    onDismiss()\n                } label: {\n                    Image(systemName: \"xmark.circle.fill\")\n                        .font(.system(size: 14))\n                        .foregroundStyle(.tertiary)\n                }\n                .buttonStyle(.plain)\n            }\n\n            Text(\"Run your first model here:\")\n                .font(.system(.subheadline, design: .default))\n                .foregroundColor(.secondary)\n\n            HStack {\n                Button {\n                    onOpen()\n                } label: {\n                    Label(\"Open Dashboard\", systemImage: \"arrow.up.right.square\")\n                        .font(.system(.caption, design: .default))\n                        .fontWeight(.medium)\n                }\n                .buttonStyle(.borderedProminent)\n                .tint(.accentColor)\n                .controlSize(.small)\n\n                Spacer()\n\n                if countdown > 0 {\n                    Text(\"Launching in \\(countdown) secs...\")\n                        .font(.system(.caption2, design: .default))\n                        .foregroundColor(.secondary.opacity(0.6))\n                        .monospacedDigit()\n                }\n            }\n        }\n      
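        // The displayed countdown is cosmetic; FirstLaunchPopout's countdownTask performs the actual auto-open and dismissal.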
  .padding(14)\n        .onAppear {\n            startCountdown()\n        }\n        .onDisappear {\n            timerTask?.cancel()\n            timerTask = nil\n        }\n    }\n\n    private func startCountdown() {\n        timerTask = Task {\n            while countdown > 0 {\n                try? await Task.sleep(nanoseconds: 1_000_000_000)\n                // A cancelled sleep returns immediately; exit instead of spinning until the countdown hits zero.\n                if Task.isCancelled { return }\n                countdown -= 1\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "app/EXO/EXO/Views/InstanceRowView.swift",
    "content": "import SwiftUI\n\nstruct InstanceRowView: View {\n    let instance: InstanceViewModel\n    @State private var animatedTaskIDs: Set<String> = []\n    @State private var infoTask: InstanceTaskViewModel?\n    @State private var showChatTasks = true\n\n    var body: some View {\n        VStack(alignment: .leading, spacing: 6) {\n            HStack(spacing: 8) {\n                VStack(alignment: .leading, spacing: 2) {\n                    Text(instance.modelName)\n                        .font(.subheadline)\n                    Text(instance.nodeSummary)\n                        .font(.caption)\n                        .foregroundColor(.secondary)\n                }\n                Spacer()\n                if let progress = instance.downloadProgress {\n                    downloadStatusView(progress: progress)\n                } else {\n                    statusChip(label: instance.state.label.uppercased(), color: statusColor)\n                }\n            }\n            if let progress = instance.downloadProgress {\n                GeometryReader { geometry in\n                    HStack {\n                        Spacer()\n                        downloadProgressBar(progress: progress)\n                            .frame(width: geometry.size.width * 0.5)\n                    }\n                }\n                .frame(height: 4)\n                .padding(.top, -8)\n                .padding(.bottom, 2)\n                HStack(spacing: 8) {\n                    Text(instance.sharding ?? \"\")\n                        .font(.caption2)\n                        .foregroundColor(.secondary)\n                    Spacer()\n                    downloadSpeedView(progress: progress)\n                }\n            } else {\n                Text(instance.sharding ?? 
\"\")\n                    .font(.caption2)\n                    .foregroundColor(.secondary)\n            }\n            if !instance.chatTasks.isEmpty {\n                VStack(alignment: .leading, spacing: 4) {\n                    HStack {\n                        Text(\"Chat Tasks\")\n                            .font(.caption2)\n                            .foregroundColor(.secondary)\n                        Text(\"(\\(instance.chatTasks.count))\")\n                            .font(.caption2)\n                            .foregroundColor(.secondary)\n                        Spacer()\n                        collapseButton(isExpanded: $showChatTasks)\n                    }\n                    .animation(nil, value: showChatTasks)\n                    if showChatTasks {\n                        VStack(alignment: .leading, spacing: 8) {\n                            ForEach(instance.chatTasks) { task in\n                                taskRow(for: task, parentModelName: instance.modelName)\n                            }\n                        }\n                        .transition(.opacity)\n                    }\n                }\n                .padding(.top, 4)\n                .animation(.easeInOut(duration: 0.25), value: showChatTasks)\n            }\n        }\n        .padding(.vertical, 6)\n    }\n\n    private var statusColor: Color {\n        switch instance.state {\n        case .downloading: return .blue\n        case .warmingUp: return .orange\n        case .running: return .green\n        case .ready: return .teal\n        case .waiting, .idle: return .gray\n        case .failed: return .red\n        case .preparing: return .secondary\n        }\n    }\n\n    @ViewBuilder\n    private func taskRow(for task: InstanceTaskViewModel, parentModelName: String) -> some View {\n        VStack(alignment: .leading, spacing: 4) {\n            HStack(alignment: .top, spacing: 8) {\n                taskStatusIcon(for: task)\n                VStack(alignment: .leading, spacing: 2) {\n                    Text(\"Chat\")\n                        .font(.caption)\n                        .fontWeight(.semibold)\n                    if let subtitle = task.subtitle,\n                        subtitle.caseInsensitiveCompare(parentModelName) != .orderedSame\n                    {\n                        Text(subtitle)\n                            .font(.caption2)\n                            .foregroundColor(.secondary)\n                    }\n                    if let prompt = task.promptPreview, !prompt.isEmpty {\n                        Text(\"⊙ \\(prompt)\")\n                            .font(.caption2)\n                            .foregroundColor(.secondary)\n                            .lineLimit(2)\n                    }\n                    if task.status == .failed, let error = task.errorMessage, !error.isEmpty {\n                        Text(error)\n                            .font(.caption2)\n                            .foregroundColor(.red)\n                            .lineLimit(3)\n                    }\n                }\n                Spacer(minLength: 6)\n                Button {\n                    infoTask = task\n                } label: {\n                    Image(systemName: \"info.circle\")\n                        .imageScale(.small)\n                }\n                .buttonStyle(.plain)\n                .popover(\n                    item: Binding<InstanceTaskViewModel?>(\n                        get: { infoTask?.id == task.id ? 
infoTask : nil },\n                        set: { infoTask = $0 }\n                    ),\n                    attachmentAnchor: .rect(.bounds),\n                    arrowEdge: .top\n                ) { _ in\n                    TaskDetailView(task: task)\n                        .padding()\n                        .frame(width: 240)\n                }\n            }\n        }\n    }\n\n    private func taskStatusIcon(for task: InstanceTaskViewModel) -> some View {\n        let icon: String\n        let color: Color\n        let animation: Animation?\n\n        switch task.status {\n        case .running:\n            icon = \"arrow.triangle.2.circlepath\"\n            color = .blue\n            animation = Animation.linear(duration: 1).repeatForever(autoreverses: false)\n        case .pending:\n            icon = \"circle.dashed\"\n            color = .secondary\n            animation = nil\n        case .failed:\n            icon = \"exclamationmark.triangle.fill\"\n            color = .red\n            animation = nil\n        case .complete:\n            icon = \"checkmark.circle.fill\"\n            color = .green\n            animation = nil\n        case .unknown:\n            icon = \"questionmark.circle\"\n            color = .secondary\n            animation = nil\n        }\n\n        let image = Image(systemName: icon)\n            .imageScale(.small)\n            .foregroundColor(color)\n\n        if let animation {\n            return AnyView(\n                image\n                    .rotationEffect(.degrees(animatedTaskIDs.contains(task.id) ? 
360 : 0))\n                    .onAppear {\n                        if !animatedTaskIDs.contains(task.id) {\n                            animatedTaskIDs.insert(task.id)\n                        }\n                    }\n                    .animation(animation, value: animatedTaskIDs)\n            )\n        }\n\n        return AnyView(image)\n    }\n\n    private func statusChip(label: String, color: Color) -> some View {\n        Text(label)\n            .font(.caption2)\n            .padding(.horizontal, 8)\n            .padding(.vertical, 4)\n            .background(color.opacity(0.15))\n            .foregroundColor(color)\n            .clipShape(Capsule())\n    }\n\n    private func downloadStatusView(progress: DownloadProgressViewModel) -> some View {\n        VStack(alignment: .trailing, spacing: 4) {\n            statusChip(label: \"DOWNLOADING\", color: .blue)\n            Text(progress.formattedProgress)\n                .foregroundColor(.primary)\n        }\n        .font(.caption2)\n    }\n\n    private func downloadSpeedView(progress: DownloadProgressViewModel) -> some View {\n        HStack(spacing: 4) {\n            Text(progress.formattedSpeed)\n            if let eta = progress.formattedETA {\n                Text(\"·\")\n                Text(eta)\n            }\n        }\n        .font(.caption2)\n        .foregroundColor(.secondary)\n    }\n\n    private func downloadProgressBar(progress: DownloadProgressViewModel) -> some View {\n        ProgressView(value: progress.fractionCompleted)\n            .progressViewStyle(.linear)\n            .tint(.blue)\n    }\n\n    private func collapseButton(isExpanded: Binding<Bool>) -> some View {\n        Button {\n            isExpanded.wrappedValue.toggle()\n        } label: {\n            Label(\n                isExpanded.wrappedValue ? \"Hide\" : \"Show\",\n                systemImage: isExpanded.wrappedValue ? 
\"chevron.up\" : \"chevron.down\"\n            )\n            .labelStyle(.titleAndIcon)\n            .contentTransition(.symbolEffect(.replace))\n        }\n        .buttonStyle(.plain)\n        .font(.caption2)\n    }\n\n    private struct TaskDetailView: View, Identifiable {\n        let task: InstanceTaskViewModel\n        var id: String { task.id }\n\n        var body: some View {\n            ScrollView {\n                VStack(alignment: .leading, spacing: 12) {\n                    parameterSection\n                    messageSection\n                    if let error = task.errorMessage, !error.isEmpty {\n                        detailRow(\n                            icon: \"exclamationmark.triangle.fill\",\n                            title: \"Error\",\n                            value: error,\n                            tint: .red\n                        )\n                    }\n                }\n            }\n        }\n\n        @ViewBuilder\n        private var parameterSection: some View {\n            if let params = task.parameters {\n                VStack(alignment: .leading, spacing: 6) {\n                    Text(\"Parameters\")\n                        .font(.subheadline)\n                    if let temperature = params.temperature {\n                        detailRow(title: \"Temperature\", value: String(format: \"%.1f\", temperature))\n                    }\n                    if let maxTokens = params.maxTokens {\n                        detailRow(title: \"Max Tokens\", value: \"\\(maxTokens)\")\n                    }\n                    if let stream = params.stream {\n                        detailRow(title: \"Stream\", value: stream ? \"On\" : \"Off\")\n                    }\n                    if let topP = params.topP {\n                        detailRow(title: \"Top P\", value: String(format: \"%.2f\", topP))\n                    }\n                }\n            }\n        }\n\n        @ViewBuilder\n        private var messageSection: some View {\n            if let messages = task.parameters?.messages, !messages.isEmpty {\n                VStack(alignment: .leading, spacing: 6) {\n                    Text(\"Messages\")\n                        .font(.subheadline)\n                    ForEach(Array(messages.enumerated()), id: \\.offset) { _, message in\n                        VStack(alignment: .leading, spacing: 2) {\n                            Text(message.role?.capitalized ?? \"Message\")\n                                .font(.caption)\n                                .foregroundColor(.secondary)\n                            if let content = message.content, !content.isEmpty {\n                                Text(content)\n                                    .font(.caption2)\n                                    .foregroundColor(.primary)\n                            }\n                        }\n                        .padding(8)\n                        .background(Color.secondary.opacity(0.08))\n                        .clipShape(RoundedRectangle(cornerRadius: 6))\n                    }\n                }\n            }\n        }\n\n        @ViewBuilder\n        private func detailRow(\n            icon: String? 
= nil, title: String, value: String, tint: Color = .secondary\n        ) -> some View {\n            HStack(alignment: .firstTextBaseline, spacing: 6) {\n                if let icon {\n                    Image(systemName: icon)\n                        .imageScale(.small)\n                        .foregroundColor(tint)\n                }\n                Text(title)\n                    .font(.caption)\n                    .foregroundColor(.secondary)\n                Spacer()\n                Text(value)\n                    .font(.caption2)\n                    .foregroundColor(.primary)\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "app/EXO/EXO/Views/NodeDetailView.swift",
    "content": "import SwiftUI\n\nstruct NodeDetailView: View {\n    let node: NodeViewModel\n\n    var body: some View {\n        VStack(alignment: .leading, spacing: 10) {\n            Text(node.friendlyName)\n                .font(.headline)\n            Text(node.model)\n                .font(.caption)\n                .foregroundColor(.secondary)\n            Divider()\n            metricRow(label: \"Memory\", value: node.memoryLabel)\n            ProgressView(value: node.memoryProgress)\n            metricRow(label: \"CPU Usage\", value: node.cpuUsageLabel)\n            metricRow(label: \"GPU Usage\", value: node.gpuUsageLabel)\n            metricRow(label: \"Temperature\", value: node.temperatureLabel)\n            metricRow(label: \"Power\", value: node.powerLabel)\n        }\n        .padding()\n    }\n\n    private func metricRow(label: String, value: String) -> some View {\n        HStack {\n            Text(label)\n                .font(.caption)\n                .foregroundColor(.secondary)\n            Spacer()\n            Text(value)\n                .font(.subheadline)\n        }\n    }\n}\n"
  },
  {
    "path": "app/EXO/EXO/Views/NodeRowView.swift",
    "content": "import SwiftUI\n\nstruct NodeRowView: View {\n    let node: NodeViewModel\n\n    var body: some View {\n        VStack(alignment: .leading, spacing: 4) {\n            HStack {\n                VStack(alignment: .leading) {\n                    Text(node.friendlyName)\n                        .font(.subheadline)\n                    Text(node.memoryLabel)\n                        .font(.caption)\n                        .foregroundColor(.secondary)\n                }\n                Spacer()\n                VStack(alignment: .trailing) {\n                    Text(\"\\(node.gpuUsagePercent, specifier: \"%.0f\")% GPU\")\n                        .font(.caption)\n                    Text(node.temperatureLabel)\n                        .font(.caption2)\n                        .foregroundColor(.secondary)\n                }\n            }\n            ProgressView(value: node.memoryProgress)\n                .progressViewStyle(.linear)\n        }\n        .padding(.vertical, 4)\n    }\n}\n"
  },
  {
    "path": "app/EXO/EXO/Views/SettingsView.swift",
    "content": "import AppKit\nimport SwiftUI\n\n/// Native macOS Settings window following Apple HIG.\n/// Organized into General, Model, Advanced, and About sections.\nstruct SettingsView: View {\n    @EnvironmentObject private var controller: ExoProcessController\n    @EnvironmentObject private var updater: SparkleUpdater\n    @EnvironmentObject private var networkStatusService: NetworkStatusService\n    @EnvironmentObject private var thunderboltBridgeService: ThunderboltBridgeService\n    @EnvironmentObject private var stateService: ClusterStateService\n\n    @State private var pendingNamespace: String = \"\"\n    @State private var pendingHFToken: String = \"\"\n    @State private var pendingEnableImageModels = false\n    @State private var pendingOfflineMode = false\n    @State private var needsRestart = false\n    @State private var bugReportInFlight = false\n    @State private var bugReportMessage: String?\n    @State private var uninstallInProgress = false\n\n    var body: some View {\n        TabView {\n            generalTab\n                .tabItem {\n                    Label(\"General\", systemImage: \"gear\")\n                }\n            modelTab\n                .tabItem {\n                    Label(\"Model\", systemImage: \"cube\")\n                }\n            advancedTab\n                .tabItem {\n                    Label(\"Advanced\", systemImage: \"wrench.and.screwdriver\")\n                }\n            aboutTab\n                .tabItem {\n                    Label(\"About\", systemImage: \"info.circle\")\n                }\n        }\n        .frame(width: 450, height: 400)\n        .onAppear {\n            pendingNamespace = controller.customNamespace\n            pendingHFToken = controller.hfToken\n            pendingEnableImageModels = controller.enableImageModels\n            pendingOfflineMode = controller.offlineMode\n            needsRestart = false\n        }\n    }\n\n    // MARK: - General Tab\n\n    private var generalTab: some View {\n        Form {\n            Section {\n                LabeledContent(\"Cluster Namespace\") {\n                    TextField(\"default\", text: $pendingNamespace)\n                        .textFieldStyle(.roundedBorder)\n                        .frame(width: 200)\n                }\n                Text(\"Nodes with the same namespace form a cluster. Leave empty for default.\")\n                    .font(.caption)\n                    .foregroundColor(.secondary)\n            }\n\n            Section {\n                LabeledContent(\"HuggingFace Token\") {\n                    SecureField(\"optional\", text: $pendingHFToken)\n                        .textFieldStyle(.roundedBorder)\n                        .frame(width: 200)\n                }\n                Text(\"Required for gated models. 
Get yours at huggingface.co/settings/tokens\")\n                    .font(.caption)\n                    .foregroundColor(.secondary)\n            }\n\n            Section {\n                Toggle(\"Offline Mode\", isOn: $pendingOfflineMode)\n                Text(\"Skip internet checks and use only locally available models.\")\n                    .font(.caption)\n                    .foregroundColor(.secondary)\n            }\n\n            Section {\n                HStack {\n                    Spacer()\n                    Button(\"Save & Restart\") {\n                        applyGeneralSettings()\n                    }\n                    .disabled(!hasGeneralChanges)\n                }\n            }\n        }\n        .formStyle(.grouped)\n        .padding()\n    }\n\n    // MARK: - Model Tab\n\n    private var modelTab: some View {\n        Form {\n            Section {\n                Toggle(\"Enable Image Models (experimental)\", isOn: $pendingEnableImageModels)\n                Text(\"Allow text-to-image and image-to-image models in the model picker.\")\n                    .font(.caption)\n                    .foregroundColor(.secondary)\n            }\n\n            Section {\n                HStack {\n                    Spacer()\n                    Button(\"Save & Restart\") {\n                        applyModelSettings()\n                    }\n                    .disabled(!hasModelChanges)\n                }\n            }\n        }\n        .formStyle(.grouped)\n        .padding()\n    }\n\n    // MARK: - Advanced Tab\n\n    private var advancedTab: some View {\n        Form {\n            Section(\"Onboarding\") {\n                HStack {\n                    VStack(alignment: .leading) {\n                        Text(\"Reset Onboarding\")\n                        Text(\"Opens the dashboard and resets the onboarding wizard.\")\n                            .font(.caption)\n                            .foregroundColor(.secondary)\n                    }\n                    Spacer()\n                    Button(\"Reset\") {\n                        guard let url = URL(string: \"http://localhost:52415/?reset-onboarding\")\n                        else { return }\n                        NSWorkspace.shared.open(url)\n                    }\n                }\n            }\n\n            Section(\"Debug Info\") {\n                LabeledContent(\"Thunderbolt Bridge\") {\n                    Text(thunderboltStatusText)\n                        .foregroundColor(thunderboltStatusColor)\n                }\n\n                VStack(alignment: .leading, spacing: 2) {\n                    clusterThunderboltBridgeView\n                }\n\n                VStack(alignment: .leading, spacing: 2) {\n                    interfaceIpList\n                }\n\n                VStack(alignment: .leading, spacing: 2) {\n                    rdmaStatusView\n                }\n\n                sendBugReportButton\n            }\n\n            Section(\"Danger Zone\") {\n                Button(role: .destructive) {\n                    showUninstallConfirmationAlert()\n                } label: {\n                    HStack {\n                        Text(\"Uninstall EXO\")\n                        Spacer()\n                        Image(systemName: \"trash\")\n                            .imageScale(.small)\n                    }\n                }\n                .disabled(uninstallInProgress)\n            }\n        }\n        .formStyle(.grouped)\n        .padding()\n    }\n\n    // 
MARK: - About Tab\n\n    private var aboutTab: some View {\n        Form {\n            Section {\n                LabeledContent(\"Version\") {\n                    Text(buildTag)\n                        .textSelection(.enabled)\n                }\n                LabeledContent(\"Commit\") {\n                    Text(buildCommit)\n                        .font(.system(.body, design: .monospaced))\n                        .textSelection(.enabled)\n                }\n            }\n\n            Section {\n                Button(\"Check for Updates\") {\n                    updater.checkForUpdates()\n                }\n            }\n        }\n        .formStyle(.grouped)\n        .padding()\n    }\n\n    // MARK: - Debug Info Views (moved from ContentView)\n\n    private var thunderboltStatusText: String {\n        switch networkStatusService.status.thunderboltBridgeState {\n        case .some(.disabled):\n            return \"Disabled\"\n        case .some(.deleted):\n            return \"Deleted\"\n        case .some(.enabled):\n            return \"Enabled\"\n        case nil:\n            return \"Unknown\"\n        }\n    }\n\n    private var thunderboltStatusColor: Color {\n        switch networkStatusService.status.thunderboltBridgeState {\n        case .some(.disabled), .some(.deleted):\n            return .green\n        case .some(.enabled):\n            return .red\n        case nil:\n            return .secondary\n        }\n    }\n\n    private var clusterThunderboltBridgeView: some View {\n        let bridgeStatuses = stateService.latestSnapshot?.nodeThunderboltBridge ?? [:]\n        let localNodeId = stateService.localNodeId\n        let nodeProfiles = stateService.latestSnapshot?.nodeProfiles ?? [:]\n\n        return VStack(alignment: .leading, spacing: 1) {\n            if bridgeStatuses.isEmpty {\n                Text(\"Cluster TB Bridge: No data\")\n                    .font(.caption2)\n                    .foregroundColor(.secondary)\n            } else {\n                Text(\"Cluster TB Bridge Status:\")\n                    .font(.caption2)\n                    .foregroundColor(.secondary)\n                ForEach(Array(bridgeStatuses.keys.sorted()), id: \\.self) { nodeId in\n                    if let status = bridgeStatuses[nodeId] {\n                        let nodeName =\n                            nodeProfiles[nodeId]?.friendlyName ?? String(nodeId.prefix(8))\n                        let isLocal = nodeId == localNodeId\n                        let prefix = isLocal ? \"  \\(nodeName) (local):\" : \"  \\(nodeName):\"\n                        let statusText =\n                            !status.exists\n                            ? \"N/A\"\n                            : (status.enabled ? \"Enabled\" : \"Disabled\")\n                        let color: Color =\n                            !status.exists\n                            ? .secondary\n                            : (status.enabled ? 
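/* an enabled bridge reads as a problem (red), mirroring thunderboltStatusColor */ 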
.red : .green)\n                        Text(\"\\(prefix) \\(statusText)\")\n                            .font(.caption2)\n                            .foregroundColor(color)\n                    }\n                }\n            }\n        }\n    }\n\n    private var interfaceIpList: some View {\n        let statuses = networkStatusService.status.interfaceStatuses\n        return VStack(alignment: .leading, spacing: 1) {\n            Text(\"Interfaces (en0–en7):\")\n                .font(.caption2)\n                .foregroundColor(.secondary)\n            if statuses.isEmpty {\n                Text(\"  Unknown\")\n                    .font(.caption2)\n                    .foregroundColor(.secondary)\n            } else {\n                ForEach(statuses, id: \\.interfaceName) { status in\n                    let ipText = status.ipAddress ?? \"No IP\"\n                    Text(\"  \\(status.interfaceName): \\(ipText)\")\n                        .font(.caption2)\n                        .foregroundColor(status.ipAddress == nil ? .red : .green)\n                }\n            }\n        }\n    }\n\n    private var rdmaStatusView: some View {\n        let rdmaStatuses = stateService.latestSnapshot?.nodeRdmaCtl ?? [:]\n        let localNodeId = stateService.localNodeId\n        let nodeProfiles = stateService.latestSnapshot?.nodeProfiles ?? [:]\n        let localDevices = networkStatusService.status.localRdmaDevices\n        let localPorts = networkStatusService.status.localRdmaActivePorts\n\n        return VStack(alignment: .leading, spacing: 1) {\n            if rdmaStatuses.isEmpty {\n                Text(\"Cluster RDMA: No data\")\n                    .font(.caption2)\n                    .foregroundColor(.secondary)\n            } else {\n                Text(\"Cluster RDMA Status:\")\n                    .font(.caption2)\n                    .foregroundColor(.secondary)\n                ForEach(Array(rdmaStatuses.keys.sorted()), id: \\.self) { nodeId in\n                    if let status = rdmaStatuses[nodeId] {\n                        let nodeName =\n                            nodeProfiles[nodeId]?.friendlyName ?? String(nodeId.prefix(8))\n                        let isLocal = nodeId == localNodeId\n                        let prefix = isLocal ? \"  \\(nodeName) (local):\" : \"  \\(nodeName):\"\n                        let statusText = status.enabled ? \"Enabled\" : \"Disabled\"\n                        let color: Color = status.enabled ? 
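/* RDMA enabled is healthy (green); disabled only warns (orange) */ 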
.green : .orange\n                        Text(\"\\(prefix) \\(statusText)\")\n                            .font(.caption2)\n                            .foregroundColor(color)\n                    }\n                }\n            }\n            if !localDevices.isEmpty {\n                Text(\"  Local Devices: \\(localDevices.joined(separator: \", \"))\")\n                    .font(.caption2)\n                    .foregroundColor(.secondary)\n            }\n            if !localPorts.isEmpty {\n                Text(\"  Local Active Ports:\")\n                    .font(.caption2)\n                    .foregroundColor(.secondary)\n                ForEach(localPorts, id: \\.device) { port in\n                    Text(\"    \\(port.device) port \\(port.port): \\(port.state)\")\n                        .font(.caption2)\n                        .foregroundColor(.green)\n                }\n            }\n        }\n    }\n\n    private var sendBugReportButton: some View {\n        VStack(alignment: .leading, spacing: 4) {\n            Button {\n                Task {\n                    await sendBugReport()\n                }\n            } label: {\n                HStack {\n                    if bugReportInFlight {\n                        ProgressView()\n                            .scaleEffect(0.6)\n                    }\n                    Text(\"Send Bug Report\")\n                        .font(.caption)\n                        .fontWeight(.semibold)\n                    Spacer()\n                }\n            }\n            .disabled(bugReportInFlight)\n\n            if let message = bugReportMessage {\n                Text(message)\n                    .font(.caption2)\n                    .foregroundColor(.secondary)\n                    .fixedSize(horizontal: false, vertical: true)\n            }\n        }\n    }\n\n    // MARK: - Actions\n\n    private func sendBugReport() async {\n        bugReportInFlight = true\n        bugReportMessage = \"Collecting logs...\"\n        let service = BugReportService()\n        do {\n            let outcome = try await service.sendReport(isManual: true)\n            bugReportMessage = outcome.message\n        } catch {\n            bugReportMessage = error.localizedDescription\n        }\n        bugReportInFlight = false\n    }\n\n    private func showUninstallConfirmationAlert() {\n        let alert = NSAlert()\n        alert.messageText = \"Uninstall EXO\"\n        alert.informativeText = \"\"\"\n            This will remove EXO and all its system components:\n\n            • Network configuration daemon\n            • Launch at login registration\n            • EXO network location\n\n            The app will be moved to Trash.\n            \"\"\"\n        alert.alertStyle = .warning\n        alert.addButton(withTitle: \"Uninstall\")\n        alert.addButton(withTitle: \"Cancel\")\n\n        if let uninstallButton = alert.buttons.first {\n            uninstallButton.hasDestructiveAction = true\n        }\n\n        let response = alert.runModal()\n        if response == .alertFirstButtonReturn {\n            performUninstall()\n        }\n    }\n\n    private func performUninstall() {\n        uninstallInProgress = true\n\n        controller.cancelPendingLaunch()\n        controller.stop()\n        stateService.stopPolling()\n\n        DispatchQueue.global(qos: .utility).async {\n            do {\n                try NetworkSetupHelper.uninstall()\n\n                DispatchQueue.main.async {\n                    
LaunchAtLoginHelper.disable()\n                    self.moveAppToTrash()\n\n                    DispatchQueue.main.asyncAfter(deadline: .now() + 0.5) {\n                        NSApplication.shared.terminate(nil)\n                    }\n                }\n            } catch {\n                DispatchQueue.main.async {\n                    let errorAlert = NSAlert()\n                    errorAlert.messageText = \"Uninstall Failed\"\n                    errorAlert.informativeText = error.localizedDescription\n                    errorAlert.alertStyle = .critical\n                    errorAlert.addButton(withTitle: \"OK\")\n                    errorAlert.runModal()\n                    self.uninstallInProgress = false\n                }\n            }\n        }\n    }\n\n    private func moveAppToTrash() {\n        let appURL = Bundle.main.bundleURL\n        do {\n            try FileManager.default.trashItem(at: appURL, resultingItemURL: nil)\n        } catch {\n            // If we can't trash the app, that's OK - user can do it manually\n        }\n    }\n\n    // MARK: - Helpers\n\n    private var hasGeneralChanges: Bool {\n        pendingNamespace != controller.customNamespace || pendingHFToken != controller.hfToken\n            || pendingOfflineMode != controller.offlineMode\n    }\n\n    private var hasModelChanges: Bool {\n        pendingEnableImageModels != controller.enableImageModels\n    }\n\n    private func applyGeneralSettings() {\n        controller.customNamespace = pendingNamespace\n        controller.hfToken = pendingHFToken\n        controller.offlineMode = pendingOfflineMode\n        restartIfRunning()\n    }\n\n    private func applyModelSettings() {\n        controller.enableImageModels = pendingEnableImageModels\n        restartIfRunning()\n    }\n\n    private func restartIfRunning() {\n        if controller.status == .running || controller.status == .starting {\n            controller.restart()\n        }\n    }\n\n    private var buildTag: String {\n        Bundle.main.infoDictionary?[\"EXOBuildTag\"] as? String ?? \"unknown\"\n    }\n\n    private var buildCommit: String {\n        Bundle.main.infoDictionary?[\"EXOBuildCommit\"] as? String ?? \"unknown\"\n    }\n}\n"
  },
  {
    "path": "app/EXO/EXO/Views/SettingsWindowController.swift",
    "content": "import AppKit\nimport SwiftUI\n\n/// Manages a standalone native macOS Settings window.\n/// Ensures only one instance exists and brings it to front on repeated opens.\n@MainActor\nfinal class SettingsWindowController: ObservableObject {\n    private var window: NSWindow?\n\n    func open(\n        controller: ExoProcessController,\n        updater: SparkleUpdater,\n        networkStatusService: NetworkStatusService,\n        thunderboltBridgeService: ThunderboltBridgeService,\n        stateService: ClusterStateService\n    ) {\n        if let existing = window, existing.isVisible {\n            existing.makeKeyAndOrderFront(nil)\n            NSApp.activate(ignoringOtherApps: true)\n            return\n        }\n\n        let settingsView = SettingsView()\n            .environmentObject(controller)\n            .environmentObject(updater)\n            .environmentObject(networkStatusService)\n            .environmentObject(thunderboltBridgeService)\n            .environmentObject(stateService)\n\n        let hostingView = NSHostingView(rootView: settingsView)\n\n        let newWindow = NSWindow(\n            contentRect: NSRect(x: 0, y: 0, width: 450, height: 400),\n            styleMask: [.titled, .closable],\n            backing: .buffered,\n            defer: false\n        )\n        newWindow.title = \"EXO Settings\"\n        newWindow.contentView = hostingView\n        newWindow.center()\n        newWindow.isReleasedWhenClosed = false\n        newWindow.makeKeyAndOrderFront(nil)\n        NSApp.activate(ignoringOtherApps: true)\n\n        window = newWindow\n    }\n}\n"
  },
  {
    "path": "app/EXO/EXO/Views/TopologyMiniView.swift",
    "content": "import SwiftUI\n\nstruct TopologyMiniView: View {\n    let topology: TopologyViewModel\n\n    var body: some View {\n        VStack(alignment: .leading, spacing: 8) {\n            Text(\"Topology\")\n                .font(.caption)\n                .foregroundColor(.secondary)\n            GeometryReader { geo in\n                ZStack {\n                    connectionLines(in: geo.size)\n                    let positions = positionedNodes(in: geo.size)\n                    ForEach(Array(positions.enumerated()), id: \\.element.node.id) { _, positioned in\n                        NodeGlyphView(node: positioned.node, isCurrent: positioned.isCurrent)\n                            .position(positioned.point)\n                    }\n                }\n            }\n            .frame(height: heightForNodes())\n        }\n    }\n\n    private func positionedNodes(in size: CGSize) -> [PositionedNode] {\n        let nodes = orderedNodesForLayout()\n        guard !nodes.isEmpty else { return [] }\n        var result: [PositionedNode] = []\n        let glyphHeight: CGFloat = 70\n        let rootPoint = CGPoint(x: size.width / 2, y: glyphHeight / 2 + 10)\n        result.append(\n            PositionedNode(\n                node: nodes[0],\n                point: rootPoint,\n                isCurrent: nodes[0].id == topology.currentNodeId\n            )\n        )\n        guard nodes.count > 1 else { return result }\n        let childCount = nodes.count - 1\n        // Larger radius to reduce overlap when several nodes exist\n        let minDimension = min(size.width, size.height)\n        let radius = max(120, minDimension * 0.42)\n        let startAngle = Double.pi * 0.75\n        let endAngle = Double.pi * 0.25\n        let step = childCount == 1 ? 
0 : (startAngle - endAngle) / Double(childCount - 1)\n        for (index, node) in nodes.dropFirst().enumerated() {\n            let angle = startAngle - step * Double(index)\n            let x = size.width / 2 + radius * CGFloat(cos(angle))\n            let y = rootPoint.y + radius * CGFloat(sin(angle)) + glyphHeight / 2\n            result.append(\n                PositionedNode(\n                    node: node,\n                    point: CGPoint(x: x, y: y),\n                    isCurrent: node.id == topology.currentNodeId\n                )\n            )\n        }\n        return result\n    }\n\n    private func orderedNodesForLayout() -> [NodeViewModel] {\n        guard let currentId = topology.currentNodeId else {\n            return topology.nodes\n        }\n        guard let currentIndex = topology.nodes.firstIndex(where: { $0.id == currentId }) else {\n            return topology.nodes\n        }\n        if currentIndex == 0 {\n            return topology.nodes\n        }\n        var reordered = topology.nodes\n        let current = reordered.remove(at: currentIndex)\n        reordered.insert(current, at: 0)\n        return reordered\n    }\n\n    private func connectionLines(in size: CGSize) -> some View {\n        let positions = positionedNodes(in: size)\n        let positionById = Dictionary(\n            uniqueKeysWithValues: positions.map { ($0.node.id, $0.point) })\n        return Canvas { context, _ in\n            guard !topology.edges.isEmpty else { return }\n            let nodeRadius: CGFloat = 32\n            let arrowLength: CGFloat = 10\n            let arrowSpread: CGFloat = .pi / 7\n            for edge in topology.edges {\n                guard let start = positionById[edge.sourceId], let end = positionById[edge.targetId]\n                else { continue }\n                let dx = end.x - start.x\n                let dy = end.y - start.y\n                let distance = max(CGFloat(hypot(dx, dy)), 1)\n                let ux = dx / distance\n                let uy = dy / distance\n                let adjustedStart = CGPoint(\n                    x: start.x + ux * nodeRadius, y: start.y + uy * nodeRadius)\n                let adjustedEnd = CGPoint(x: end.x - ux * nodeRadius, y: end.y - uy * nodeRadius)\n\n                var linePath = Path()\n                linePath.move(to: adjustedStart)\n                linePath.addLine(to: adjustedEnd)\n                context.stroke(\n                    linePath,\n                    with: .color(.secondary.opacity(0.3)),\n                    style: StrokeStyle(lineWidth: 1, dash: [4, 4])\n                )\n\n                let angle = atan2(uy, ux)\n                let tip = adjustedEnd\n                let leftWing = CGPoint(\n                    x: tip.x - arrowLength * cos(angle - arrowSpread),\n                    y: tip.y - arrowLength * sin(angle - arrowSpread)\n                )\n                let rightWing = CGPoint(\n                    x: tip.x - arrowLength * cos(angle + arrowSpread),\n                    y: tip.y - arrowLength * sin(angle + arrowSpread)\n                )\n                var arrowPath = Path()\n                arrowPath.move(to: tip)\n                arrowPath.addLine(to: leftWing)\n                arrowPath.move(to: tip)\n                arrowPath.addLine(to: rightWing)\n                context.stroke(\n                    arrowPath,\n                    with: .color(.secondary.opacity(0.5)),\n                    style: StrokeStyle(lineWidth: 1)\n                )\n            }\n  
      }\n    }\n\n    private func heightForNodes() -> CGFloat {\n        switch topology.nodes.count {\n        case 0...1:\n            return 130\n        case 2...3:\n            return 200\n        default:\n            return 240\n        }\n    }\n\n    private struct PositionedNode {\n        let node: NodeViewModel\n        let point: CGPoint\n        let isCurrent: Bool\n    }\n}\n\nprivate struct NodeGlyphView: View {\n    let node: NodeViewModel\n    let isCurrent: Bool\n\n    var body: some View {\n        VStack(spacing: 2) {\n            Image(systemName: node.deviceIconName)\n                .font(.subheadline)\n            Text(node.friendlyName)\n                .font(.caption2)\n                .lineLimit(1)\n                .foregroundColor(isCurrent ? Color(nsColor: .systemBlue) : .primary)\n            Text(node.memoryLabel)\n                .font(.caption2)\n            HStack(spacing: 3) {\n                Text(node.gpuUsageLabel)\n                Text(node.temperatureLabel)\n            }\n            .foregroundColor(.secondary)\n            .font(.caption2)\n        }\n        .padding(.vertical, 3)\n        .frame(width: 95)\n    }\n}\n"
  },
  {
    "path": "app/EXO/EXO/main.swift",
    "content": "//\n//  main.swift\n//  EXO\n//\n//  Created by Jake Hillion on 2026-02-03.\n//\n\nimport Foundation\n\n/// Command line options for the EXO app\nenum CLICommand {\n    case install\n    case uninstall\n    case help\n    case none\n}\n\n/// Parse command line arguments to determine the CLI command\nfunc parseArguments() -> CLICommand {\n    let args = CommandLine.arguments\n    if args.contains(\"--help\") || args.contains(\"-h\") {\n        return .help\n    }\n    if args.contains(\"--install\") {\n        return .install\n    }\n    if args.contains(\"--uninstall\") {\n        return .uninstall\n    }\n    return .none\n}\n\n/// Print usage information\nfunc printUsage() {\n    let programName = (CommandLine.arguments.first as NSString?)?.lastPathComponent ?? \"EXO\"\n    print(\n        \"\"\"\n        Usage: \\(programName) [OPTIONS]\n\n        Options:\n          --install     Install EXO network configuration (requires root)\n          --uninstall   Uninstall EXO network configuration (requires root)\n          --help, -h    Show this help message\n\n        When run without options, starts the normal GUI application.\n\n        Examples:\n          sudo \\(programName) --install    Install network components as root\n          sudo \\(programName) --uninstall  Remove network components as root\n        \"\"\")\n}\n\n/// Check if running as root\nfunc isRunningAsRoot() -> Bool {\n    return getuid() == 0\n}\n\n// Main entry point\nlet command = parseArguments()\n\nswitch command {\ncase .help:\n    printUsage()\n    exit(0)\n\ncase .install:\n    if !isRunningAsRoot() {\n        fputs(\"Error: --install requires root privileges. Run with sudo.\\n\", stderr)\n        exit(1)\n    }\n    let success = NetworkSetupHelper.installDirectly()\n    exit(success ? 0 : 1)\n\ncase .uninstall:\n    if !isRunningAsRoot() {\n        fputs(\"Error: --uninstall requires root privileges. Run with sudo.\\n\", stderr)\n        exit(1)\n    }\n    let success = NetworkSetupHelper.uninstallDirectly()\n    exit(success ? 0 : 1)\n\ncase .none:\n    // Start normal GUI application\n    EXOApp.main()\n}\n"
  },
  {
    "path": "app/EXO/EXO.xcodeproj/project.pbxproj",
    "content": "// !$*UTF8*$!\n{\n\tarchiveVersion = 1;\n\tclasses = {\n\t};\n\tobjectVersion = 77;\n\tobjects = {\n\n/* Begin PBXBuildFile section */\n\t\tE0140D402ED1F909001F3171 /* exo in Resources */ = {isa = PBXBuildFile; fileRef = E0140D3F2ED1F909001F3171 /* exo */; };\n\t\tE0A1B1002F5A000100000003 /* Sparkle in Frameworks */ = {isa = PBXBuildFile; productRef = E0A1B1002F5A000100000002 /* Sparkle */; };\n/* End PBXBuildFile section */\n\n/* Begin PBXContainerItemProxy section */\n\t\tE0140D212ED1F79B001F3171 /* PBXContainerItemProxy */ = {\n\t\t\tisa = PBXContainerItemProxy;\n\t\t\tcontainerPortal = E0140D072ED1F79A001F3171 /* Project object */;\n\t\t\tproxyType = 1;\n\t\t\tremoteGlobalIDString = E0140D0E2ED1F79A001F3171;\n\t\t\tremoteInfo = EXO;\n\t\t};\n\t\tE0140D2B2ED1F79B001F3171 /* PBXContainerItemProxy */ = {\n\t\t\tisa = PBXContainerItemProxy;\n\t\t\tcontainerPortal = E0140D072ED1F79A001F3171 /* Project object */;\n\t\t\tproxyType = 1;\n\t\t\tremoteGlobalIDString = E0140D0E2ED1F79A001F3171;\n\t\t\tremoteInfo = EXO;\n\t\t};\n/* End PBXContainerItemProxy section */\n\n/* Begin PBXFileReference section */\n\t\tE0140D0F2ED1F79A001F3171 /* EXO.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = EXO.app; sourceTree = BUILT_PRODUCTS_DIR; };\n\t\tE0140D202ED1F79B001F3171 /* EXOTests.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = EXOTests.xctest; sourceTree = BUILT_PRODUCTS_DIR; };\n\t\tE0140D2A2ED1F79B001F3171 /* EXOUITests.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = EXOUITests.xctest; sourceTree = BUILT_PRODUCTS_DIR; };\n\t\tE0140D3F2ED1F909001F3171 /* exo */ = {isa = PBXFileReference; lastKnownFileType = folder; name = exo; path = ../../dist/exo; sourceTree = \"<group>\"; };\n/* End PBXFileReference section */\n\n/* Begin PBXFileSystemSynchronizedRootGroup section */\n\t\tE0140D112ED1F79A001F3171 /* EXO */ = {\n\t\t\tisa = PBXFileSystemSynchronizedRootGroup;\n\t\t\tpath = EXO;\n\t\t\tsourceTree = \"<group>\";\n\t\t};\n\t\tE0140D232ED1F79B001F3171 /* EXOTests */ = {\n\t\t\tisa = PBXFileSystemSynchronizedRootGroup;\n\t\t\tpath = EXOTests;\n\t\t\tsourceTree = \"<group>\";\n\t\t};\n\t\tE0140D2D2ED1F79B001F3171 /* EXOUITests */ = {\n\t\t\tisa = PBXFileSystemSynchronizedRootGroup;\n\t\t\tpath = EXOUITests;\n\t\t\tsourceTree = \"<group>\";\n\t\t};\n/* End PBXFileSystemSynchronizedRootGroup section */\n\n/* Begin PBXFrameworksBuildPhase section */\n\t\tE0140D0C2ED1F79A001F3171 /* Frameworks */ = {\n\t\t\tisa = PBXFrameworksBuildPhase;\n\t\t\tbuildActionMask = 2147483647;\n\t\t\tfiles = (\n\t\t\t\tE0A1B1002F5A000100000003 /* Sparkle in Frameworks */,\n\t\t\t);\n\t\t\trunOnlyForDeploymentPostprocessing = 0;\n\t\t};\n\t\tE0140D1D2ED1F79B001F3171 /* Frameworks */ = {\n\t\t\tisa = PBXFrameworksBuildPhase;\n\t\t\tbuildActionMask = 2147483647;\n\t\t\tfiles = (\n\t\t\t);\n\t\t\trunOnlyForDeploymentPostprocessing = 0;\n\t\t};\n\t\tE0140D272ED1F79B001F3171 /* Frameworks */ = {\n\t\t\tisa = PBXFrameworksBuildPhase;\n\t\t\tbuildActionMask = 2147483647;\n\t\t\tfiles = (\n\t\t\t);\n\t\t\trunOnlyForDeploymentPostprocessing = 0;\n\t\t};\n/* End PBXFrameworksBuildPhase section */\n\n/* Begin PBXGroup section */\n\t\tE0140D062ED1F79A001F3171 = {\n\t\t\tisa = PBXGroup;\n\t\t\tchildren = (\n\t\t\t\tE0140D3F2ED1F909001F3171 /* exo */,\n\t\t\t\tE0140D112ED1F79A001F3171 /* EXO */,\n\t\t\t\tE0140D232ED1F79B001F3171 /* EXOTests 
*/,\n\t\t\t\tE0140D2D2ED1F79B001F3171 /* EXOUITests */,\n\t\t\t\tE0140D102ED1F79A001F3171 /* Products */,\n\t\t\t);\n\t\t\tsourceTree = \"<group>\";\n\t\t};\n\t\tE0140D102ED1F79A001F3171 /* Products */ = {\n\t\t\tisa = PBXGroup;\n\t\t\tchildren = (\n\t\t\t\tE0140D0F2ED1F79A001F3171 /* EXO.app */,\n\t\t\t\tE0140D202ED1F79B001F3171 /* EXOTests.xctest */,\n\t\t\t\tE0140D2A2ED1F79B001F3171 /* EXOUITests.xctest */,\n\t\t\t);\n\t\t\tname = Products;\n\t\t\tsourceTree = \"<group>\";\n\t\t};\n/* End PBXGroup section */\n\n/* Begin PBXNativeTarget section */\n\t\tE0140D0E2ED1F79A001F3171 /* EXO */ = {\n\t\t\tisa = PBXNativeTarget;\n\t\t\tbuildConfigurationList = E0140D342ED1F79B001F3171 /* Build configuration list for PBXNativeTarget \"EXO\" */;\n\t\t\tbuildPhases = (\n\t\t\t\tE0140D0B2ED1F79A001F3171 /* Sources */,\n\t\t\t\tE0140D0C2ED1F79A001F3171 /* Frameworks */,\n\t\t\t\tE0140D0D2ED1F79A001F3171 /* Resources */,\n\t\t\t);\n\t\t\tbuildRules = (\n\t\t\t);\n\t\t\tdependencies = (\n\t\t\t);\n\t\t\tfileSystemSynchronizedGroups = (\n\t\t\t\tE0140D112ED1F79A001F3171 /* EXO */,\n\t\t\t);\n\t\t\tname = EXO;\n\t\t\tpackageProductDependencies = (\n\t\t\t\tE0A1B1002F5A000100000002 /* Sparkle */,\n\t\t\t);\n\t\t\tproductName = EXO;\n\t\t\tproductReference = E0140D0F2ED1F79A001F3171 /* EXO.app */;\n\t\t\tproductType = \"com.apple.product-type.application\";\n\t\t};\n\t\tE0140D1F2ED1F79B001F3171 /* EXOTests */ = {\n\t\t\tisa = PBXNativeTarget;\n\t\t\tbuildConfigurationList = E0140D372ED1F79B001F3171 /* Build configuration list for PBXNativeTarget \"EXOTests\" */;\n\t\t\tbuildPhases = (\n\t\t\t\tE0140D1C2ED1F79B001F3171 /* Sources */,\n\t\t\t\tE0140D1D2ED1F79B001F3171 /* Frameworks */,\n\t\t\t\tE0140D1E2ED1F79B001F3171 /* Resources */,\n\t\t\t);\n\t\t\tbuildRules = (\n\t\t\t);\n\t\t\tdependencies = (\n\t\t\t\tE0140D222ED1F79B001F3171 /* PBXTargetDependency */,\n\t\t\t);\n\t\t\tfileSystemSynchronizedGroups = (\n\t\t\t\tE0140D232ED1F79B001F3171 /* EXOTests */,\n\t\t\t);\n\t\t\tname = EXOTests;\n\t\t\tpackageProductDependencies = (\n\t\t\t);\n\t\t\tproductName = EXOTests;\n\t\t\tproductReference = E0140D202ED1F79B001F3171 /* EXOTests.xctest */;\n\t\t\tproductType = \"com.apple.product-type.bundle.unit-test\";\n\t\t};\n\t\tE0140D292ED1F79B001F3171 /* EXOUITests */ = {\n\t\t\tisa = PBXNativeTarget;\n\t\t\tbuildConfigurationList = E0140D3A2ED1F79B001F3171 /* Build configuration list for PBXNativeTarget \"EXOUITests\" */;\n\t\t\tbuildPhases = (\n\t\t\t\tE0140D262ED1F79B001F3171 /* Sources */,\n\t\t\t\tE0140D272ED1F79B001F3171 /* Frameworks */,\n\t\t\t\tE0140D282ED1F79B001F3171 /* Resources */,\n\t\t\t);\n\t\t\tbuildRules = (\n\t\t\t);\n\t\t\tdependencies = (\n\t\t\t\tE0140D2C2ED1F79B001F3171 /* PBXTargetDependency */,\n\t\t\t);\n\t\t\tfileSystemSynchronizedGroups = (\n\t\t\t\tE0140D2D2ED1F79B001F3171 /* EXOUITests */,\n\t\t\t);\n\t\t\tname = EXOUITests;\n\t\t\tpackageProductDependencies = (\n\t\t\t);\n\t\t\tproductName = EXOUITests;\n\t\t\tproductReference = E0140D2A2ED1F79B001F3171 /* EXOUITests.xctest */;\n\t\t\tproductType = \"com.apple.product-type.bundle.ui-testing\";\n\t\t};\n/* End PBXNativeTarget section */\n\n/* Begin PBXProject section */\n\t\tE0140D072ED1F79A001F3171 /* Project object */ = {\n\t\t\tisa = PBXProject;\n\t\t\tattributes = {\n\t\t\t\tBuildIndependentTargetsInParallel = 1;\n\t\t\t\tLastSwiftUpdateCheck = 1610;\n\t\t\t\tLastUpgradeCheck = 1610;\n\t\t\t\tTargetAttributes = {\n\t\t\t\t\tE0140D0E2ED1F79A001F3171 = {\n\t\t\t\t\t\tCreatedOnToolsVersion = 
16.1;\n\t\t\t\t\t};\n\t\t\t\t\tE0140D1F2ED1F79B001F3171 = {\n\t\t\t\t\t\tCreatedOnToolsVersion = 16.1;\n\t\t\t\t\t\tTestTargetID = E0140D0E2ED1F79A001F3171;\n\t\t\t\t\t};\n\t\t\t\t\tE0140D292ED1F79B001F3171 = {\n\t\t\t\t\t\tCreatedOnToolsVersion = 16.1;\n\t\t\t\t\t\tTestTargetID = E0140D0E2ED1F79A001F3171;\n\t\t\t\t\t};\n\t\t\t\t};\n\t\t\t};\n\t\t\tbuildConfigurationList = E0140D0A2ED1F79A001F3171 /* Build configuration list for PBXProject \"EXO\" */;\n\t\t\tdevelopmentRegion = en;\n\t\t\thasScannedForEncodings = 0;\n\t\t\tknownRegions = (\n\t\t\t\ten,\n\t\t\t\tBase,\n\t\t\t);\n\t\t\tmainGroup = E0140D062ED1F79A001F3171;\n\t\t\tminimizedProjectReferenceProxies = 1;\n\t\t\tpackageReferences = (\n\t\t\t\tE0A1B1002F5A000100000001 /* XCRemoteSwiftPackageReference \"Sparkle\" */,\n\t\t\t);\n\t\t\tpreferredProjectObjectVersion = 77;\n\t\t\tproductRefGroup = E0140D102ED1F79A001F3171 /* Products */;\n\t\t\tprojectDirPath = \"\";\n\t\t\tprojectRoot = \"\";\n\t\t\ttargets = (\n\t\t\t\tE0140D0E2ED1F79A001F3171 /* EXO */,\n\t\t\t\tE0140D1F2ED1F79B001F3171 /* EXOTests */,\n\t\t\t\tE0140D292ED1F79B001F3171 /* EXOUITests */,\n\t\t\t);\n\t\t};\n/* End PBXProject section */\n\n/* Begin PBXResourcesBuildPhase section */\n\t\tE0140D0D2ED1F79A001F3171 /* Resources */ = {\n\t\t\tisa = PBXResourcesBuildPhase;\n\t\t\tbuildActionMask = 2147483647;\n\t\t\tfiles = (\n\t\t\t\tE0140D402ED1F909001F3171 /* exo in Resources */,\n\t\t\t);\n\t\t\trunOnlyForDeploymentPostprocessing = 0;\n\t\t};\n\t\tE0140D1E2ED1F79B001F3171 /* Resources */ = {\n\t\t\tisa = PBXResourcesBuildPhase;\n\t\t\tbuildActionMask = 2147483647;\n\t\t\tfiles = (\n\t\t\t);\n\t\t\trunOnlyForDeploymentPostprocessing = 0;\n\t\t};\n\t\tE0140D282ED1F79B001F3171 /* Resources */ = {\n\t\t\tisa = PBXResourcesBuildPhase;\n\t\t\tbuildActionMask = 2147483647;\n\t\t\tfiles = (\n\t\t\t);\n\t\t\trunOnlyForDeploymentPostprocessing = 0;\n\t\t};\n/* End PBXResourcesBuildPhase section */\n\n/* Begin PBXSourcesBuildPhase section */\n\t\tE0140D0B2ED1F79A001F3171 /* Sources */ = {\n\t\t\tisa = PBXSourcesBuildPhase;\n\t\t\tbuildActionMask = 2147483647;\n\t\t\tfiles = (\n\t\t\t);\n\t\t\trunOnlyForDeploymentPostprocessing = 0;\n\t\t};\n\t\tE0140D1C2ED1F79B001F3171 /* Sources */ = {\n\t\t\tisa = PBXSourcesBuildPhase;\n\t\t\tbuildActionMask = 2147483647;\n\t\t\tfiles = (\n\t\t\t);\n\t\t\trunOnlyForDeploymentPostprocessing = 0;\n\t\t};\n\t\tE0140D262ED1F79B001F3171 /* Sources */ = {\n\t\t\tisa = PBXSourcesBuildPhase;\n\t\t\tbuildActionMask = 2147483647;\n\t\t\tfiles = (\n\t\t\t);\n\t\t\trunOnlyForDeploymentPostprocessing = 0;\n\t\t};\n/* End PBXSourcesBuildPhase section */\n\n/* Begin PBXTargetDependency section */\n\t\tE0140D222ED1F79B001F3171 /* PBXTargetDependency */ = {\n\t\t\tisa = PBXTargetDependency;\n\t\t\ttarget = E0140D0E2ED1F79A001F3171 /* EXO */;\n\t\t\ttargetProxy = E0140D212ED1F79B001F3171 /* PBXContainerItemProxy */;\n\t\t};\n\t\tE0140D2C2ED1F79B001F3171 /* PBXTargetDependency */ = {\n\t\t\tisa = PBXTargetDependency;\n\t\t\ttarget = E0140D0E2ED1F79A001F3171 /* EXO */;\n\t\t\ttargetProxy = E0140D2B2ED1F79B001F3171 /* PBXContainerItemProxy */;\n\t\t};\n/* End PBXTargetDependency section */\n\n/* Begin XCBuildConfiguration section */\n\t\tE0140D322ED1F79B001F3171 /* Debug */ = {\n\t\t\tisa = XCBuildConfiguration;\n\t\t\tbuildSettings = {\n\t\t\t\tALWAYS_SEARCH_USER_PATHS = NO;\n\t\t\t\tASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;\n\t\t\t\tCLANG_ANALYZER_NONNULL = YES;\n\t\t\t\tCLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = 
YES_AGGRESSIVE;\n\t\t\t\tCLANG_CXX_LANGUAGE_STANDARD = \"gnu++20\";\n\t\t\t\tCLANG_ENABLE_MODULES = YES;\n\t\t\t\tCLANG_ENABLE_OBJC_ARC = YES;\n\t\t\t\tCLANG_ENABLE_OBJC_WEAK = YES;\n\t\t\t\tCLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;\n\t\t\t\tCLANG_WARN_BOOL_CONVERSION = YES;\n\t\t\t\tCLANG_WARN_COMMA = YES;\n\t\t\t\tCLANG_WARN_CONSTANT_CONVERSION = YES;\n\t\t\t\tCLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;\n\t\t\t\tCLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;\n\t\t\t\tCLANG_WARN_DOCUMENTATION_COMMENTS = YES;\n\t\t\t\tCLANG_WARN_EMPTY_BODY = YES;\n\t\t\t\tCLANG_WARN_ENUM_CONVERSION = YES;\n\t\t\t\tCLANG_WARN_INFINITE_RECURSION = YES;\n\t\t\t\tCLANG_WARN_INT_CONVERSION = YES;\n\t\t\t\tCLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;\n\t\t\t\tCLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;\n\t\t\t\tCLANG_WARN_OBJC_LITERAL_CONVERSION = YES;\n\t\t\t\tCLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;\n\t\t\t\tCLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;\n\t\t\t\tCLANG_WARN_RANGE_LOOP_ANALYSIS = YES;\n\t\t\t\tCLANG_WARN_STRICT_PROTOTYPES = YES;\n\t\t\t\tCLANG_WARN_SUSPICIOUS_MOVE = YES;\n\t\t\t\tCLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;\n\t\t\t\tCLANG_WARN_UNREACHABLE_CODE = YES;\n\t\t\t\tCLANG_WARN__DUPLICATE_METHOD_MATCH = YES;\n\t\t\t\tCOPY_PHASE_STRIP = NO;\n\t\t\t\tDEBUG_INFORMATION_FORMAT = dwarf;\n\t\t\t\tENABLE_STRICT_OBJC_MSGSEND = YES;\n\t\t\t\tENABLE_TESTABILITY = YES;\n\t\t\t\tENABLE_USER_SCRIPT_SANDBOXING = YES;\n\t\t\t\tGCC_C_LANGUAGE_STANDARD = gnu17;\n\t\t\t\tGCC_DYNAMIC_NO_PIC = NO;\n\t\t\t\tGCC_NO_COMMON_BLOCKS = YES;\n\t\t\t\tGCC_OPTIMIZATION_LEVEL = 0;\n\t\t\t\tGCC_PREPROCESSOR_DEFINITIONS = (\n\t\t\t\t\t\"DEBUG=1\",\n\t\t\t\t\t\"$(inherited)\",\n\t\t\t\t);\n\t\t\t\tGCC_WARN_64_TO_32_BIT_CONVERSION = YES;\n\t\t\t\tGCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;\n\t\t\t\tGCC_WARN_UNDECLARED_SELECTOR = YES;\n\t\t\t\tGCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;\n\t\t\t\tGCC_WARN_UNUSED_FUNCTION = YES;\n\t\t\t\tGCC_WARN_UNUSED_VARIABLE = YES;\n\t\t\t\tLOCALIZATION_PREFERS_STRING_CATALOGS = YES;\n\t\t\t\tMACOSX_DEPLOYMENT_TARGET = 15.1;\n\t\t\t\tMTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE;\n\t\t\t\tMTL_FAST_MATH = YES;\n\t\t\t\tONLY_ACTIVE_ARCH = YES;\n\t\t\t\tSDKROOT = macosx;\n\t\t\t\tSWIFT_ACTIVE_COMPILATION_CONDITIONS = \"DEBUG $(inherited)\";\n\t\t\t\tSWIFT_OPTIMIZATION_LEVEL = \"-Onone\";\n\t\t\t\tSWIFT_TREAT_WARNINGS_AS_ERRORS = YES;\n\t\t\t\tGCC_TREAT_WARNINGS_AS_ERRORS = YES;\n\t\t\t};\n\t\t\tname = Debug;\n\t\t};\n\t\tE0140D332ED1F79B001F3171 /* Release */ = {\n\t\t\tisa = XCBuildConfiguration;\n\t\t\tbuildSettings = {\n\t\t\t\tALWAYS_SEARCH_USER_PATHS = NO;\n\t\t\t\tASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;\n\t\t\t\tCLANG_ANALYZER_NONNULL = YES;\n\t\t\t\tCLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;\n\t\t\t\tCLANG_CXX_LANGUAGE_STANDARD = \"gnu++20\";\n\t\t\t\tCLANG_ENABLE_MODULES = YES;\n\t\t\t\tCLANG_ENABLE_OBJC_ARC = YES;\n\t\t\t\tCLANG_ENABLE_OBJC_WEAK = YES;\n\t\t\t\tCLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;\n\t\t\t\tCLANG_WARN_BOOL_CONVERSION = YES;\n\t\t\t\tCLANG_WARN_COMMA = YES;\n\t\t\t\tCLANG_WARN_CONSTANT_CONVERSION = YES;\n\t\t\t\tCLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;\n\t\t\t\tCLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;\n\t\t\t\tCLANG_WARN_DOCUMENTATION_COMMENTS = YES;\n\t\t\t\tCLANG_WARN_EMPTY_BODY = YES;\n\t\t\t\tCLANG_WARN_ENUM_CONVERSION = YES;\n\t\t\t\tCLANG_WARN_INFINITE_RECURSION = YES;\n\t\t\t\tCLANG_WARN_INT_CONVERSION = YES;\n\t\t\t\tCLANG_WARN_NON_LITERAL_NULL_CONVERSION = 
YES;\n\t\t\t\tCLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;\n\t\t\t\tCLANG_WARN_OBJC_LITERAL_CONVERSION = YES;\n\t\t\t\tCLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;\n\t\t\t\tCLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;\n\t\t\t\tCLANG_WARN_RANGE_LOOP_ANALYSIS = YES;\n\t\t\t\tCLANG_WARN_STRICT_PROTOTYPES = YES;\n\t\t\t\tCLANG_WARN_SUSPICIOUS_MOVE = YES;\n\t\t\t\tCLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;\n\t\t\t\tCLANG_WARN_UNREACHABLE_CODE = YES;\n\t\t\t\tCLANG_WARN__DUPLICATE_METHOD_MATCH = YES;\n\t\t\t\tCOPY_PHASE_STRIP = NO;\n\t\t\t\tDEBUG_INFORMATION_FORMAT = \"dwarf-with-dsym\";\n\t\t\t\tENABLE_NS_ASSERTIONS = NO;\n\t\t\t\tENABLE_STRICT_OBJC_MSGSEND = YES;\n\t\t\t\tENABLE_USER_SCRIPT_SANDBOXING = YES;\n\t\t\t\tGCC_C_LANGUAGE_STANDARD = gnu17;\n\t\t\t\tGCC_NO_COMMON_BLOCKS = YES;\n\t\t\t\tGCC_WARN_64_TO_32_BIT_CONVERSION = YES;\n\t\t\t\tGCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;\n\t\t\t\tGCC_WARN_UNDECLARED_SELECTOR = YES;\n\t\t\t\tGCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;\n\t\t\t\tGCC_WARN_UNUSED_FUNCTION = YES;\n\t\t\t\tGCC_WARN_UNUSED_VARIABLE = YES;\n\t\t\t\tLOCALIZATION_PREFERS_STRING_CATALOGS = YES;\n\t\t\t\tMACOSX_DEPLOYMENT_TARGET = 15.1;\n\t\t\t\tMTL_ENABLE_DEBUG_INFO = NO;\n\t\t\t\tMTL_FAST_MATH = YES;\n\t\t\t\tSDKROOT = macosx;\n\t\t\t\tSWIFT_COMPILATION_MODE = wholemodule;\n\t\t\t\tSWIFT_TREAT_WARNINGS_AS_ERRORS = YES;\n\t\t\t\tGCC_TREAT_WARNINGS_AS_ERRORS = YES;\n\t\t\t};\n\t\t\tname = Release;\n\t\t};\n\t\tE0140D352ED1F79B001F3171 /* Debug */ = {\n\t\t\tisa = XCBuildConfiguration;\n\t\t\tbuildSettings = {\n\t\t\t\tASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;\n\t\t\t\tASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;\n\t\t\t\tCODE_SIGN_ENTITLEMENTS = EXO/EXO.entitlements;\n\t\t\t\tCODE_SIGN_STYLE = Automatic;\n\t\t\t\tCOMBINE_HIDPI_IMAGES = YES;\n\t\t\t\tCURRENT_PROJECT_VERSION = 1;\n\t\t\t\tDEVELOPMENT_ASSET_PATHS = \"\\\"EXO/Preview Content\\\"\";\n\t\t\t\tENABLE_PREVIEWS = YES;\n\t\t\t\tGENERATE_INFOPLIST_FILE = YES;\n\t\t\t\tINFOPLIST_FILE = EXO/Info.plist;\n\t\t\t\tINFOPLIST_KEY_LSUIElement = YES;\n\t\t\t\tINFOPLIST_KEY_EXOBuildCommit = \"$(EXO_BUILD_COMMIT)\";\n\t\t\t\tINFOPLIST_KEY_EXOBuildTag = \"$(EXO_BUILD_TAG)\";\n\t\t\t\tINFOPLIST_KEY_NSAppleEventsUsageDescription = \"EXO needs to run a signed network setup script with administrator privileges.\";\n\t\t\t\tINFOPLIST_KEY_NSHumanReadableCopyright = \"\";\n\t\t\t\tINFOPLIST_KEY_SUEnableAutomaticChecks = YES;\n\t\t\t\tINFOPLIST_KEY_SUFeedURL = \"$(SPARKLE_FEED_URL)\";\n\t\t\t\tINFOPLIST_KEY_SUPublicEDKey = \"$(SPARKLE_ED25519_PUBLIC)\";\n\t\t\t\tLD_RUNPATH_SEARCH_PATHS = (\n\t\t\t\t\t\"$(inherited)\",\n\t\t\t\t\t\"@executable_path/../Frameworks\",\n\t\t\t\t);\n\t\t\t\tMARKETING_VERSION = 1.0.1;\n\t\t\t\tPRODUCT_BUNDLE_IDENTIFIER = exolabs.EXO;\n\t\t\t\tPRODUCT_NAME = \"$(TARGET_NAME)\";\n\t\t\t\tEXO_BUILD_COMMIT = local;\n\t\t\t\tEXO_BUILD_TAG = dev;\n\t\t\t\tSPARKLE_ED25519_PUBLIC = \"\";\n\t\t\t\tSPARKLE_FEED_URL = \"\";\n\t\t\t\tSWIFT_EMIT_LOC_STRINGS = YES;\n\t\t\t\tSWIFT_VERSION = 5.0;\n\t\t\t};\n\t\t\tname = Debug;\n\t\t};\n\t\tE0140D362ED1F79B001F3171 /* Release */ = {\n\t\t\tisa = XCBuildConfiguration;\n\t\t\tbuildSettings = {\n\t\t\t\tASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;\n\t\t\t\tASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;\n\t\t\t\tCODE_SIGN_ENTITLEMENTS = EXO/EXO.entitlements;\n\t\t\t\tCODE_SIGN_STYLE = Automatic;\n\t\t\t\tCOMBINE_HIDPI_IMAGES = YES;\n\t\t\t\tCURRENT_PROJECT_VERSION = 1;\n\t\t\t\tDEVELOPMENT_ASSET_PATHS = \"\\\"EXO/Preview 
Content\\\"\";\n\t\t\t\tENABLE_PREVIEWS = YES;\n\t\t\t\tGENERATE_INFOPLIST_FILE = YES;\n\t\t\t\tINFOPLIST_FILE = EXO/Info.plist;\n\t\t\t\tINFOPLIST_KEY_LSUIElement = YES;\n\t\t\t\tINFOPLIST_KEY_EXOBuildCommit = \"$(EXO_BUILD_COMMIT)\";\n\t\t\t\tINFOPLIST_KEY_EXOBuildTag = \"$(EXO_BUILD_TAG)\";\n\t\t\t\tINFOPLIST_KEY_NSAppleEventsUsageDescription = \"EXO needs to run a signed network setup script with administrator privileges.\";\n\t\t\t\tINFOPLIST_KEY_NSHumanReadableCopyright = \"\";\n\t\t\t\tINFOPLIST_KEY_SUEnableAutomaticChecks = YES;\n\t\t\t\tINFOPLIST_KEY_SUFeedURL = \"$(SPARKLE_FEED_URL)\";\n\t\t\t\tINFOPLIST_KEY_SUPublicEDKey = \"$(SPARKLE_ED25519_PUBLIC)\";\n\t\t\t\tLD_RUNPATH_SEARCH_PATHS = (\n\t\t\t\t\t\"$(inherited)\",\n\t\t\t\t\t\"@executable_path/../Frameworks\",\n\t\t\t\t);\n\t\t\t\tMARKETING_VERSION = 1.0.1;\n\t\t\t\tPRODUCT_BUNDLE_IDENTIFIER = exolabs.EXO;\n\t\t\t\tPRODUCT_NAME = \"$(TARGET_NAME)\";\n\t\t\t\tEXO_BUILD_COMMIT = local;\n\t\t\t\tEXO_BUILD_TAG = dev;\n\t\t\t\tSPARKLE_ED25519_PUBLIC = \"\";\n\t\t\t\tSPARKLE_FEED_URL = \"\";\n\t\t\t\tSWIFT_EMIT_LOC_STRINGS = YES;\n\t\t\t\tSWIFT_VERSION = 5.0;\n\t\t\t};\n\t\t\tname = Release;\n\t\t};\n\t\tE0140D382ED1F79B001F3171 /* Debug */ = {\n\t\t\tisa = XCBuildConfiguration;\n\t\t\tbuildSettings = {\n\t\t\t\tBUNDLE_LOADER = \"$(TEST_HOST)\";\n\t\t\t\tCODE_SIGN_STYLE = Automatic;\n\t\t\t\tCURRENT_PROJECT_VERSION = 1;\n\t\t\t\tGENERATE_INFOPLIST_FILE = YES;\n\t\t\t\tMACOSX_DEPLOYMENT_TARGET = 15.1;\n\t\t\t\tMARKETING_VERSION = 1.0;\n\t\t\t\tPRODUCT_BUNDLE_IDENTIFIER = exolabs.EXOTests;\n\t\t\t\tPRODUCT_NAME = \"$(TARGET_NAME)\";\n\t\t\t\tSWIFT_EMIT_LOC_STRINGS = NO;\n\t\t\t\tSWIFT_VERSION = 5.0;\n\t\t\t\tTEST_HOST = \"$(BUILT_PRODUCTS_DIR)/EXO.app/$(BUNDLE_EXECUTABLE_FOLDER_PATH)/EXO\";\n\t\t\t};\n\t\t\tname = Debug;\n\t\t};\n\t\tE0140D392ED1F79B001F3171 /* Release */ = {\n\t\t\tisa = XCBuildConfiguration;\n\t\t\tbuildSettings = {\n\t\t\t\tBUNDLE_LOADER = \"$(TEST_HOST)\";\n\t\t\t\tCODE_SIGN_STYLE = Automatic;\n\t\t\t\tCURRENT_PROJECT_VERSION = 1;\n\t\t\t\tGENERATE_INFOPLIST_FILE = YES;\n\t\t\t\tMACOSX_DEPLOYMENT_TARGET = 15.1;\n\t\t\t\tMARKETING_VERSION = 1.0;\n\t\t\t\tPRODUCT_BUNDLE_IDENTIFIER = exolabs.EXOTests;\n\t\t\t\tPRODUCT_NAME = \"$(TARGET_NAME)\";\n\t\t\t\tSWIFT_EMIT_LOC_STRINGS = NO;\n\t\t\t\tSWIFT_VERSION = 5.0;\n\t\t\t\tTEST_HOST = \"$(BUILT_PRODUCTS_DIR)/EXO.app/$(BUNDLE_EXECUTABLE_FOLDER_PATH)/EXO\";\n\t\t\t};\n\t\t\tname = Release;\n\t\t};\n\t\tE0140D3B2ED1F79B001F3171 /* Debug */ = {\n\t\t\tisa = XCBuildConfiguration;\n\t\t\tbuildSettings = {\n\t\t\t\tCODE_SIGN_STYLE = Automatic;\n\t\t\t\tCURRENT_PROJECT_VERSION = 1;\n\t\t\t\tGENERATE_INFOPLIST_FILE = YES;\n\t\t\t\tMARKETING_VERSION = 1.0;\n\t\t\t\tPRODUCT_BUNDLE_IDENTIFIER = exolabs.EXOUITests;\n\t\t\t\tPRODUCT_NAME = \"$(TARGET_NAME)\";\n\t\t\t\tSWIFT_EMIT_LOC_STRINGS = NO;\n\t\t\t\tSWIFT_VERSION = 5.0;\n\t\t\t\tTEST_TARGET_NAME = EXO;\n\t\t\t};\n\t\t\tname = Debug;\n\t\t};\n\t\tE0140D3C2ED1F79B001F3171 /* Release */ = {\n\t\t\tisa = XCBuildConfiguration;\n\t\t\tbuildSettings = {\n\t\t\t\tCODE_SIGN_STYLE = Automatic;\n\t\t\t\tCURRENT_PROJECT_VERSION = 1;\n\t\t\t\tGENERATE_INFOPLIST_FILE = YES;\n\t\t\t\tMARKETING_VERSION = 1.0;\n\t\t\t\tPRODUCT_BUNDLE_IDENTIFIER = exolabs.EXOUITests;\n\t\t\t\tPRODUCT_NAME = \"$(TARGET_NAME)\";\n\t\t\t\tSWIFT_EMIT_LOC_STRINGS = NO;\n\t\t\t\tSWIFT_VERSION = 5.0;\n\t\t\t\tTEST_TARGET_NAME = EXO;\n\t\t\t};\n\t\t\tname = Release;\n\t\t};\n/* End XCBuildConfiguration section */\n\n/* Begin XCConfigurationList section 
*/\n\t\tE0140D0A2ED1F79A001F3171 /* Build configuration list for PBXProject \"EXO\" */ = {\n\t\t\tisa = XCConfigurationList;\n\t\t\tbuildConfigurations = (\n\t\t\t\tE0140D322ED1F79B001F3171 /* Debug */,\n\t\t\t\tE0140D332ED1F79B001F3171 /* Release */,\n\t\t\t);\n\t\t\tdefaultConfigurationIsVisible = 0;\n\t\t\tdefaultConfigurationName = Release;\n\t\t};\n\t\tE0140D342ED1F79B001F3171 /* Build configuration list for PBXNativeTarget \"EXO\" */ = {\n\t\t\tisa = XCConfigurationList;\n\t\t\tbuildConfigurations = (\n\t\t\t\tE0140D352ED1F79B001F3171 /* Debug */,\n\t\t\t\tE0140D362ED1F79B001F3171 /* Release */,\n\t\t\t);\n\t\t\tdefaultConfigurationIsVisible = 0;\n\t\t\tdefaultConfigurationName = Release;\n\t\t};\n\t\tE0140D372ED1F79B001F3171 /* Build configuration list for PBXNativeTarget \"EXOTests\" */ = {\n\t\t\tisa = XCConfigurationList;\n\t\t\tbuildConfigurations = (\n\t\t\t\tE0140D382ED1F79B001F3171 /* Debug */,\n\t\t\t\tE0140D392ED1F79B001F3171 /* Release */,\n\t\t\t);\n\t\t\tdefaultConfigurationIsVisible = 0;\n\t\t\tdefaultConfigurationName = Release;\n\t\t};\n\t\tE0140D3A2ED1F79B001F3171 /* Build configuration list for PBXNativeTarget \"EXOUITests\" */ = {\n\t\t\tisa = XCConfigurationList;\n\t\t\tbuildConfigurations = (\n\t\t\t\tE0140D3B2ED1F79B001F3171 /* Debug */,\n\t\t\t\tE0140D3C2ED1F79B001F3171 /* Release */,\n\t\t\t);\n\t\t\tdefaultConfigurationIsVisible = 0;\n\t\t\tdefaultConfigurationName = Release;\n\t\t};\n/* End XCConfigurationList section */\n\n/* Begin XCRemoteSwiftPackageReference section */\n\t\tE0A1B1002F5A000100000001 /* XCRemoteSwiftPackageReference \"Sparkle\" */ = {\n\t\t\tisa = XCRemoteSwiftPackageReference;\n\t\t\trepositoryURL = \"https://github.com/sparkle-project/Sparkle.git\";\n\t\t\trequirement = {\n\t\t\t\tkind = upToNextMajorVersion;\n\t\t\t\tminimumVersion = 2.9.0-beta.1;\n\t\t\t};\n\t\t};\n/* End XCRemoteSwiftPackageReference section */\n\n/* Begin XCSwiftPackageProductDependency section */\n\t\tE0A1B1002F5A000100000002 /* Sparkle */ = {\n\t\t\tisa = XCSwiftPackageProductDependency;\n\t\t\tpackage = E0A1B1002F5A000100000001 /* XCRemoteSwiftPackageReference \"Sparkle\" */;\n\t\t\tproductName = Sparkle;\n\t\t};\n/* End XCSwiftPackageProductDependency section */\n\t};\n\trootObject = E0140D072ED1F79A001F3171 /* Project object */;\n}\n"
  },
  {
    "path": "app/EXO/EXO.xcodeproj/project.xcworkspace/contents.xcworkspacedata",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<Workspace\n   version = \"1.0\">\n   <FileRef\n      location = \"self:\">\n   </FileRef>\n</Workspace>\n"
  },
  {
    "path": "app/EXO/EXO.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved",
    "content": "{\n  \"originHash\" : \"5751fcbe53b64441ed73aceb16987d6b3fc3ebc666cb9ec2de1f6a2d441f2515\",\n  \"pins\" : [\n    {\n      \"identity\" : \"sparkle\",\n      \"kind\" : \"remoteSourceControl\",\n      \"location\" : \"https://github.com/sparkle-project/Sparkle.git\",\n      \"state\" : {\n        \"revision\" : \"e641adb41915a8409895e2e30666aa64e487b637\",\n        \"version\" : \"2.9.0-beta.1\"\n      }\n    }\n  ],\n  \"version\" : 3\n}\n"
  },
  {
    "path": "app/EXO/EXO.xcodeproj/xcshareddata/xcschemes/EXO.xcscheme",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<Scheme\n   LastUpgradeVersion = \"1610\"\n   version = \"1.7\">\n   <BuildAction\n      parallelizeBuildables = \"YES\"\n      buildImplicitDependencies = \"YES\"\n      buildArchitectures = \"Automatic\">\n      <BuildActionEntries>\n         <BuildActionEntry\n            buildForTesting = \"YES\"\n            buildForRunning = \"YES\"\n            buildForProfiling = \"YES\"\n            buildForArchiving = \"YES\"\n            buildForAnalyzing = \"YES\">\n            <BuildableReference\n               BuildableIdentifier = \"primary\"\n               BlueprintIdentifier = \"E0140D0E2ED1F79A001F3171\"\n               BuildableName = \"EXO.app\"\n               BlueprintName = \"EXO\"\n               ReferencedContainer = \"container:EXO.xcodeproj\">\n            </BuildableReference>\n         </BuildActionEntry>\n      </BuildActionEntries>\n   </BuildAction>\n   <TestAction\n      buildConfiguration = \"Debug\"\n      selectedDebuggerIdentifier = \"Xcode.DebuggerFoundation.Debugger.LLDB\"\n      selectedLauncherIdentifier = \"Xcode.DebuggerFoundation.Launcher.LLDB\"\n      shouldUseLaunchSchemeArgsEnv = \"YES\"\n      shouldAutocreateTestPlan = \"YES\">\n      <Testables>\n         <TestableReference\n            skipped = \"NO\"\n            parallelizable = \"YES\">\n            <BuildableReference\n               BuildableIdentifier = \"primary\"\n               BlueprintIdentifier = \"E0140D1F2ED1F79B001F3171\"\n               BuildableName = \"EXOTests.xctest\"\n               BlueprintName = \"EXOTests\"\n               ReferencedContainer = \"container:EXO.xcodeproj\">\n            </BuildableReference>\n         </TestableReference>\n         <TestableReference\n            skipped = \"NO\"\n            parallelizable = \"YES\">\n            <BuildableReference\n               BuildableIdentifier = \"primary\"\n               BlueprintIdentifier = \"E0140D292ED1F79B001F3171\"\n               BuildableName = \"EXOUITests.xctest\"\n               BlueprintName = \"EXOUITests\"\n               ReferencedContainer = \"container:EXO.xcodeproj\">\n            </BuildableReference>\n         </TestableReference>\n      </Testables>\n   </TestAction>\n   <LaunchAction\n      buildConfiguration = \"Debug\"\n      selectedDebuggerIdentifier = \"Xcode.DebuggerFoundation.Debugger.LLDB\"\n      selectedLauncherIdentifier = \"Xcode.DebuggerFoundation.Launcher.LLDB\"\n      launchStyle = \"0\"\n      useCustomWorkingDirectory = \"NO\"\n      ignoresPersistentStateOnLaunch = \"NO\"\n      debugDocumentVersioning = \"YES\"\n      debugServiceExtension = \"internal\"\n      allowLocationSimulation = \"YES\">\n      <BuildableProductRunnable\n         runnableDebuggingMode = \"0\">\n         <BuildableReference\n            BuildableIdentifier = \"primary\"\n            BlueprintIdentifier = \"E0140D0E2ED1F79A001F3171\"\n            BuildableName = \"EXO.app\"\n            BlueprintName = \"EXO\"\n            ReferencedContainer = \"container:EXO.xcodeproj\">\n         </BuildableReference>\n      </BuildableProductRunnable>\n      <EnvironmentVariables>\n         <EnvironmentVariable\n            key = \"EXO_BUG_AWS_ACCESS_KEY_ID\"\n            value = \"AKIAYEKP5EMXTOBYDGHX\"\n            isEnabled = \"YES\">\n         </EnvironmentVariable>\n         <EnvironmentVariable\n            key = \"EXO_BUG_AWS_SECRET_ACCESS_KEY\"\n            value = \"Ep5gIlUZ1o8ssTLQwmyy34yPGfTPEYQ4evE8NdPE\"\n            isEnabled = \"YES\">\n         
</EnvironmentVariable>\n      </EnvironmentVariables>\n   </LaunchAction>\n   <ProfileAction\n      buildConfiguration = \"Release\"\n      shouldUseLaunchSchemeArgsEnv = \"YES\"\n      savedToolIdentifier = \"\"\n      useCustomWorkingDirectory = \"NO\"\n      debugDocumentVersioning = \"YES\">\n      <BuildableProductRunnable\n         runnableDebuggingMode = \"0\">\n         <BuildableReference\n            BuildableIdentifier = \"primary\"\n            BlueprintIdentifier = \"E0140D0E2ED1F79A001F3171\"\n            BuildableName = \"EXO.app\"\n            BlueprintName = \"EXO\"\n            ReferencedContainer = \"container:EXO.xcodeproj\">\n         </BuildableReference>\n      </BuildableProductRunnable>\n   </ProfileAction>\n   <AnalyzeAction\n      buildConfiguration = \"Debug\">\n   </AnalyzeAction>\n   <ArchiveAction\n      buildConfiguration = \"Release\"\n      revealArchiveInOrganizer = \"YES\">\n   </ArchiveAction>\n</Scheme>\n"
  },
  {
    "path": "app/EXO/EXOTests/EXOTests.swift",
    "content": "//\n//  EXOTests.swift\n//  EXOTests\n//\n//  Created by Sami Khan on 2025-11-22.\n//\n\nimport Testing\n\n@testable import EXO\n\nstruct EXOTests {\n\n    @Test func example() async throws {\n        // Write your test here and use APIs like `#expect(...)` to check expected conditions.\n    }\n\n}\n"
  },
  {
    "path": "app/EXO/EXOUITests/EXOUITests.swift",
    "content": "//\n//  EXOUITests.swift\n//  EXOUITests\n//\n//  Created by Sami Khan on 2025-11-22.\n//\n\nimport XCTest\n\nfinal class EXOUITests: XCTestCase {\n\n    override func setUpWithError() throws {\n        // Put setup code here. This method is called before the invocation of each test method in the class.\n\n        // In UI tests it is usually best to stop immediately when a failure occurs.\n        continueAfterFailure = false\n\n        // In UI tests it’s important to set the initial state - such as interface orientation - required for your tests before they run. The setUp method is a good place to do this.\n    }\n\n    override func tearDownWithError() throws {\n        // Put teardown code here. This method is called after the invocation of each test method in the class.\n    }\n\n    @MainActor\n    func testExample() throws {\n        // UI tests must launch the application that they test.\n        let app = XCUIApplication()\n        app.launch()\n\n        // Use XCTAssert and related functions to verify your tests produce the correct results.\n    }\n\n    @MainActor\n    func testLaunchPerformance() throws {\n        if #available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 7.0, *) {\n            // This measures how long it takes to launch your application.\n            measure(metrics: [XCTApplicationLaunchMetric()]) {\n                XCUIApplication().launch()\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "app/EXO/EXOUITests/EXOUITestsLaunchTests.swift",
    "content": "//\n//  EXOUITestsLaunchTests.swift\n//  EXOUITests\n//\n//  Created by Sami Khan on 2025-11-22.\n//\n\nimport XCTest\n\nfinal class EXOUITestsLaunchTests: XCTestCase {\n\n    override class var runsForEachTargetApplicationUIConfiguration: Bool {\n        true\n    }\n\n    override func setUpWithError() throws {\n        continueAfterFailure = false\n    }\n\n    @MainActor\n    func testLaunch() throws {\n        let app = XCUIApplication()\n        app.launch()\n\n        // Insert steps here to perform after app launch but before taking a screenshot,\n        // such as logging into a test account or navigating somewhere in the app\n\n        let attachment = XCTAttachment(screenshot: app.screenshot())\n        attachment.name = \"Launch Screen\"\n        attachment.lifetime = .keepAlways\n        add(attachment)\n    }\n}\n"
  },
  {
    "path": "app/EXO/uninstall-exo.sh",
    "content": "#!/usr/bin/env bash\n#\n# EXO Uninstaller Script\n#\n# This script removes all EXO system components that persist after deleting the app.\n# Run with: sudo ./uninstall-exo.sh\n#\n# Components removed:\n# - LaunchDaemon: /Library/LaunchDaemons/io.exo.networksetup.plist\n# - Network script: /Library/Application Support/EXO/\n# - Log files: /var/log/io.exo.networksetup.*\n# - Network location: \"exo\"\n# - Launch at login registration\n#\n\nset -euo pipefail\n\nLABEL=\"io.exo.networksetup\"\nSCRIPT_DEST=\"/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh\"\nPLIST_DEST=\"/Library/LaunchDaemons/io.exo.networksetup.plist\"\nLOG_OUT=\"/var/log/${LABEL}.log\"\nLOG_ERR=\"/var/log/${LABEL}.err.log\"\nAPP_BUNDLE_ID=\"io.exo.EXO\"\n\n# Colors for output\nRED='\\033[0;31m'\nGREEN='\\033[0;32m'\nYELLOW='\\033[1;33m'\nNC='\\033[0m' # No Color\n\necho_info() {\n  echo -e \"${GREEN}[INFO]${NC} $1\"\n}\n\necho_warn() {\n  echo -e \"${YELLOW}[WARN]${NC} $1\"\n}\n\necho_error() {\n  echo -e \"${RED}[ERROR]${NC} $1\"\n}\n\n# Check if running as root\nif [[ $EUID -ne 0 ]]; then\n  echo_error \"This script must be run as root (use sudo)\"\n  exit 1\nfi\n\necho \"\"\necho \"========================================\"\necho \"        EXO Uninstaller\"\necho \"========================================\"\necho \"\"\n\n# Unload the LaunchDaemon if running\necho_info \"Stopping network setup daemon...\"\nif launchctl list | grep -q \"$LABEL\"; then\n  launchctl bootout system/\"$LABEL\" 2>/dev/null || true\n  echo_info \"Daemon stopped\"\nelse\n  echo_warn \"Daemon was not running\"\nfi\n\n# Remove LaunchDaemon plist\nif [[ -f $PLIST_DEST ]]; then\n  rm -f \"$PLIST_DEST\"\n  echo_info \"Removed LaunchDaemon plist\"\nelse\n  echo_warn \"LaunchDaemon plist not found (already removed?)\"\nfi\n\n# Remove the script and parent directory\nif [[ -f $SCRIPT_DEST ]]; then\n  rm -f \"$SCRIPT_DEST\"\n  echo_info \"Removed network setup script\"\nelse\n  echo_warn \"Network setup script not found (already removed?)\"\nfi\n\n# Remove EXO directory if empty\nif [[ -d \"/Library/Application Support/EXO\" ]]; then\n  rmdir \"/Library/Application Support/EXO\" 2>/dev/null &&\n    echo_info \"Removed EXO support directory\" ||\n    echo_warn \"EXO support directory not empty, leaving in place\"\nfi\n\n# Remove log files\nif [[ -f $LOG_OUT ]] || [[ -f $LOG_ERR ]]; then\n  rm -f \"$LOG_OUT\" \"$LOG_ERR\"\n  echo_info \"Removed log files\"\nelse\n  echo_warn \"Log files not found (already removed?)\"\nfi\n\n# Switch back to Automatic network location\necho_info \"Restoring network configuration...\"\nif networksetup -listlocations | grep -q \"^Automatic$\"; then\n  networksetup -switchtolocation Automatic 2>/dev/null || true\n  echo_info \"Switched to Automatic network location\"\nelse\n  echo_warn \"Automatic network location not found\"\nfi\n\n# Delete the exo network location if it exists\nif networksetup -listlocations | grep -q \"^exo$\"; then\n  networksetup -deletelocation exo 2>/dev/null || true\n  echo_info \"Deleted 'exo' network location\"\nelse\n  echo_warn \"'exo' network location not found (already removed?)\"\nfi\n\n# Re-enable Thunderbolt Bridge if it exists\nif networksetup -listnetworkservices 2>/dev/null | grep -q \"Thunderbolt Bridge\"; then\n  networksetup -setnetworkserviceenabled \"Thunderbolt Bridge\" on 2>/dev/null || true\n  echo_info \"Re-enabled Thunderbolt Bridge\"\nfi\n\n# Note about launch at login registration\n# SMAppService-based login items cannot be removed from a shell 
script.\n# They can only be unregistered from within the app itself or manually via System Settings.\necho_warn \"Launch at login must be removed manually:\"\necho_warn \"  System Settings → General → Login Items → Remove EXO\"\n\n# Check if EXO.app exists in common locations\nAPP_FOUND=false\nfor app_path in \"/Applications/EXO.app\" \"$HOME/Applications/EXO.app\"; do\n  if [[ -d $app_path ]]; then\n    if [[ $APP_FOUND == false ]]; then\n      echo \"\"\n      APP_FOUND=true\n    fi\n    echo_warn \"EXO.app found at: $app_path\"\n    echo_warn \"You may want to move it to Trash manually.\"\n  fi\ndone\n\necho \"\"\necho \"========================================\"\necho_info \"EXO uninstall complete!\"\necho \"========================================\"\necho \"\"\necho \"The following have been removed:\"\necho \"  • Network setup LaunchDaemon\"\necho \"  • Network configuration script\"\necho \"  • Log files\"\necho \"  • 'exo' network location\"\necho \"\"\necho \"Your network has been restored to use the 'Automatic' location.\"\necho \"Thunderbolt Bridge has been re-enabled (if present).\"\necho \"\"\necho \"Manual step required:\"\necho \"  Remove EXO from Login Items in System Settings → General → Login Items\"\necho \"\"\n"
  },
  {
    "path": "bench/bench.toml",
    "content": "# Canary benchmark manifest\n#\n# Lists the suite files to include. Each file defines benchmarks\n# with shared constraints, topology, and default args.\ninclude = [\"single-m3-ultra.toml\"]\n"
  },
  {
    "path": "bench/eval_configs/models.toml",
    "content": "# Model evaluation configurations for exo_eval.\n#\n# Each [[model]] entry uses `patterns` — a list of substrings matched\n# against the model_id. First matching entry wins.\n#\n# Required fields:\n#   name, patterns, reasoning\n#\n# Optional per-model overrides (CLI flags take priority over these):\n#   temperature, top_p, max_tokens, reasoning_effort\n#\n# Fallback defaults (when no per-model config):\n#   reasoning:     temperature=1.0, max_tokens=131072, reasoning_effort=\"high\"\n#   non-reasoning: temperature=0.0, max_tokens=16384\n#\n# All per-model values below are sourced from official model cards,\n# generation_config.json files, and vendor documentation.\n\n# ─── Qwen3.5 (Feb 2026) ─────────────────────────────────────────────\n# Source: HuggingFace model cards (Qwen/Qwen3.5-*)\n# 35B-A3B thinking general: temp=1.0, top_p=0.95, top_k=20\n# 397B thinking: temp=0.6, top_p=0.95, top_k=20\n# Non-thinking: temp=0.7, top_p=0.8, top_k=20\n# max_tokens: 32768 general, 81920 for complex math/code\n\n[[model]]\nname = \"Qwen3.5 2B\"\npatterns = [\"Qwen3.5-2B\"]\nreasoning = true\ntemperature = 0.6\ntop_p = 0.95\nmax_tokens = 81920\n\n[[model]]\nname = \"Qwen3.5 9B\"\npatterns = [\"Qwen3.5-9B\"]\nreasoning = true\ntemperature = 0.6\ntop_p = 0.95\nmax_tokens = 81920\n\n[[model]]\nname = \"Qwen3.5 27B\"\npatterns = [\"Qwen3.5-27B\"]\nreasoning = true\ntemperature = 0.6\ntop_p = 0.95\nmax_tokens = 81920\n\n[[model]]\nname = \"Qwen3.5 35B A3B\"\npatterns = [\"Qwen3.5-35B-A3B\"]\nreasoning = true\ntemperature = 1.0\ntop_p = 0.95\nmax_tokens = 81920\n\n[[model]]\nname = \"Qwen3.5 122B A10B\"\npatterns = [\"Qwen3.5-122B-A10B\"]\nreasoning = true\ntemperature = 0.6\ntop_p = 0.95\nmax_tokens = 81920\n\n[[model]]\nname = \"Qwen3.5 397B A17B\"\npatterns = [\"Qwen3.5-397B-A17B\"]\nreasoning = true\ntemperature = 0.6\ntop_p = 0.95\nmax_tokens = 81920\n\n# ─── Qwen3 (Apr 2025) ───────────────────────────────────────────────\n# Source: HuggingFace model cards (Qwen/Qwen3-*)\n# Thinking: temp=0.6, top_p=0.95, top_k=20\n# Non-thinking: temp=0.7, top_p=0.8, top_k=20\n# max_tokens: 32768 general, 38912 for complex math/code\n\n[[model]]\nname = \"Qwen3 0.6B\"\npatterns = [\"Qwen3-0.6B\"]\nreasoning = true\ntemperature = 0.6\ntop_p = 0.95\nmax_tokens = 38912\n\n[[model]]\nname = \"Qwen3 30B A3B\"\npatterns = [\"Qwen3-30B-A3B\"]\nreasoning = true\ntemperature = 0.6\ntop_p = 0.95\nmax_tokens = 38912\n\n[[model]]\nname = \"Qwen3 235B A22B\"\npatterns = [\"Qwen3-235B-A22B\"]\nreasoning = true\ntemperature = 0.6\ntop_p = 0.95\nmax_tokens = 38912\n\n[[model]]\nname = \"Qwen3 Next 80B Thinking\"\npatterns = [\"Qwen3-Next-80B-A3B-Thinking\"]\nreasoning = true\ntemperature = 0.6\ntop_p = 0.95\nmax_tokens = 38912\n\n[[model]]\nname = \"Qwen3 Next 80B Instruct\"\npatterns = [\"Qwen3-Next-80B-A3B-Instruct\"]\nreasoning = false\ntemperature = 0.7\ntop_p = 0.8\nmax_tokens = 16384\n\n[[model]]\nname = \"Qwen3 Coder 480B\"\npatterns = [\"Qwen3-Coder-480B\"]\nreasoning = false\ntemperature = 0.7\ntop_p = 0.8\nmax_tokens = 16384\n\n[[model]]\nname = \"Qwen3 Coder Next\"\npatterns = [\"Qwen3-Coder-Next\"]\nreasoning = false\ntemperature = 0.7\ntop_p = 0.8\nmax_tokens = 16384\n\n# ─── GPT-OSS (OpenAI) ───────────────────────────────────────────────\n# Source: OpenAI GitHub README + HuggingFace discussion #21\n# temp=1.0, top_p=1.0, NO top_k, NO repetition_penalty\n# reasoning_effort supported: low/medium/high\n\n[[model]]\nname = \"GPT-OSS 20B\"\npatterns = [\"gpt-oss-20b\"]\nreasoning = true\ntemperature = 
1.0\ntop_p = 1.0\n\n[[model]]\nname = \"GPT-OSS 120B\"\npatterns = [\"gpt-oss-120b\"]\nreasoning = true\ntemperature = 1.0\ntop_p = 1.0\n\n# ─── DeepSeek ────────────────────────────────────────────────────────\n# Source: https://api-docs.deepseek.com/quick_start/parameter_settings\n# Coding/Math: temp=0.0, General: temp=1.3, Creative: temp=1.5\n# NOTE: DeepSeek API applies nonlinear temp mapping. These are API values.\n# When running model directly: API temp 1.0 = model temp 0.3\n# We use temp=0.0 for eval (coding/math focus).\n\n[[model]]\nname = \"DeepSeek V3.1\"\npatterns = [\"DeepSeek-V3.1\"]\nreasoning = true\ntemperature = 0.0\n\n# ─── GLM (ZhipuAI / THUDM) ──────────────────────────────────────────\n# Source: HuggingFace model cards + generation_config.json + docs.z.ai\n# GLM 4.5+: temp=1.0, top_p=0.95\n# Reasoning tasks: 131072 max_tokens; coding/SWE tasks: temp=0.7\n\n[[model]]\nname = \"GLM-5\"\npatterns = [\"GLM-5\"]\nreasoning = true\ntemperature = 1.0\ntop_p = 0.95\nmax_tokens = 131072\n\n[[model]]\nname = \"GLM 4.5 Air\"\npatterns = [\"GLM-4.5-Air\"]\nreasoning = true\ntemperature = 1.0\ntop_p = 0.95\n\n[[model]]\nname = \"GLM 4.7\"\npatterns = [\"GLM-4.7-\"]\nreasoning = true\ntemperature = 1.0\ntop_p = 0.95\nmax_tokens = 131072\n# Note: matches both GLM-4.7 and GLM-4.7-Flash\n\n# ─── Kimi (Moonshot AI) ─────────────────────────────────────────────\n# Source: HuggingFace model cards (moonshotai/Kimi-K2-*)\n# K2-Instruct: temp=0.6\n# K2-Thinking: temp=1.0, max_length=262144\n# K2.5: thinking temp=1.0, top_p=0.95; instant temp=0.6, top_p=0.95\n\n[[model]]\nname = \"Kimi K2 Thinking\"\npatterns = [\"Kimi-K2-Thinking\"]\nreasoning = true\ntemperature = 1.0\nmax_tokens = 131072\n\n[[model]]\nname = \"Kimi K2.5\"\npatterns = [\"Kimi-K2.5\"]\nreasoning = true\ntemperature = 1.0\ntop_p = 0.95\nmax_tokens = 131072\n\n[[model]]\nname = \"Kimi K2 Instruct\"\npatterns = [\"Kimi-K2-Instruct\"]\nreasoning = false\ntemperature = 0.6\n\n# ─── MiniMax ─────────────────────────────────────────────────────────\n# Source: HuggingFace model cards + generation_config.json\n# All models: temp=1.0, top_p=0.95, top_k=40\n\n[[model]]\nname = \"MiniMax M2.5\"\npatterns = [\"MiniMax-M2.5\"]\nreasoning = true\ntemperature = 1.0\ntop_p = 0.95\n\n[[model]]\nname = \"MiniMax M2.1\"\npatterns = [\"MiniMax-M2.1\"]\nreasoning = true\ntemperature = 1.0\ntop_p = 0.95\n\n# ─── Step (StepFun) ─────────────────────────────────────────────────\n# Source: HuggingFace model card (stepfun-ai/Step-3.5-Flash)\n# Reasoning: temp=1.0, top_p=0.95\n# General chat: temp=0.6, top_p=0.95\n# We use reasoning settings for eval.\n\n[[model]]\nname = \"Step 3.5 Flash\"\npatterns = [\"Step-3.5-Flash\"]\nreasoning = true\ntemperature = 1.0\ntop_p = 0.95\n\n# ─── Llama (Meta) ───────────────────────────────────────────────────\n# Source: generation_config.json + meta-llama/llama-models generation.py\n# All variants: temp=0.6, top_p=0.9\n\n[[model]]\nname = \"Llama 3.2 1B\"\npatterns = [\"Llama-3.2-1B\"]\nreasoning = false\ntemperature = 0.6\ntop_p = 0.9\n\n[[model]]\nname = \"Llama 3.2 3B\"\npatterns = [\"Llama-3.2-3B\"]\nreasoning = false\ntemperature = 0.6\ntop_p = 0.9\n\n[[model]]\nname = \"Llama 3.1 8B\"\npatterns = [\"Llama-3.1-8B\", \"Meta-Llama-3.1-8B\"]\nreasoning = false\ntemperature = 0.6\ntop_p = 0.9\n\n[[model]]\nname = \"Llama 3.1 70B\"\npatterns = [\"Llama-3.1-70B\", \"Meta-Llama-3.1-70B\"]\nreasoning = false\ntemperature = 0.6\ntop_p = 0.9\n\n[[model]]\nname = \"Llama 3.3 70B\"\npatterns = [\"Llama-3.3-70B\", 
\"llama-3.3-70b\"]\nreasoning = false\ntemperature = 0.6\ntop_p = 0.9\n"
  },
  {
    "path": "bench/eval_tool_calls.py",
    "content": "# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false\nfrom __future__ import annotations\n\nimport argparse\nimport contextlib\nimport json\nimport os\nimport sys\nimport time\nimport tomllib\nfrom dataclasses import dataclass, field\nfrom pathlib import Path\nfrom typing import Any, Literal\n\nimport httpx\nfrom harness import (\n    ExoClient,\n    ExoHttpError,\n    add_common_instance_args,\n    instance_id_from_instance,\n    nodes_used_in_instance,\n    resolve_model_short_id,\n    run_planning_phase,\n    settle_and_fetch_placements,\n    wait_for_instance_gone,\n    wait_for_instance_ready,\n)\n\nSCENARIOS_PATH = Path(__file__).parent / \"scenarios.toml\"\n\n\n@dataclass\nclass Scenario:\n    name: str\n    description: str\n    messages: list[dict[str, Any]]\n    tools: list[dict[str, Any]]\n    expect_tool_call: bool\n    expected_function: str | None = None\n    required_arg_keys: list[str] | None = None\n    tool_result: str | None = None\n    nested_array_key: str | None = None\n    required_item_keys: list[str] | None = None\n\n\ndef load_scenarios(path: Path) -> list[Scenario]:\n    with open(path, \"rb\") as f:\n        data = tomllib.load(f)\n\n    tools_data = data.get(\"tools\", {})\n    all_tools: list[dict[str, Any]] = []\n    tool_by_name: dict[str, dict[str, Any]] = {}\n    for name, defn in tools_data.items():\n        tool: dict[str, Any] = {\n            \"type\": \"function\",\n            \"function\": {\n                \"name\": name,\n                \"description\": defn.get(\"description\", \"\"),\n                \"parameters\": {\n                    \"type\": \"object\",\n                    \"properties\": defn.get(\"properties\", {}),\n                    \"required\": defn.get(\"required\", []),\n                },\n            },\n        }\n        all_tools.append(tool)\n        tool_by_name[name] = tool\n\n    scenarios: list[Scenario] = []\n    for s in data.get(\"scenarios\", []):\n        if \"tools\" in s:\n            scenario_tools = [tool_by_name[t] for t in s[\"tools\"]]\n        else:\n            scenario_tools = list(all_tools)\n\n        messages: list[dict[str, Any]] = []\n        for msg in s.get(\"messages\", []):\n            m: dict[str, Any] = {\"role\": msg[\"role\"]}\n            if \"content\" in msg:\n                m[\"content\"] = msg[\"content\"]\n            if \"tool_calls\" in msg:\n                m[\"tool_calls\"] = [\n                    {\n                        \"id\": tc[\"id\"],\n                        \"type\": \"function\",\n                        \"function\": {\n                            \"name\": tc[\"name\"],\n                            \"arguments\": json.dumps(tc[\"arguments\"]),\n                        },\n                    }\n                    for tc in msg[\"tool_calls\"]\n                ]\n            if \"tool_call_id\" in msg:\n                m[\"tool_call_id\"] = msg[\"tool_call_id\"]\n            messages.append(m)\n\n        tool_result: str | None = None\n        if \"tool_result\" in s:\n            tool_result = json.dumps(s[\"tool_result\"])\n\n        scenarios.append(\n            Scenario(\n                name=s[\"name\"],\n                description=s[\"description\"],\n                messages=messages,\n                tools=scenario_tools,\n                expect_tool_call=s[\"expect_tool_call\"],\n                expected_function=s.get(\"expected_function\"),\n                
required_arg_keys=s.get(\"required_arg_keys\"),\n                tool_result=tool_result,\n                nested_array_key=s.get(\"nested_array_key\"),\n                required_item_keys=s.get(\"required_item_keys\"),\n            )\n        )\n\n    return scenarios\n\n\nApiName = Literal[\"openai\", \"claude\", \"responses\"]\n\n\n@dataclass\nclass ParsedResponse:\n    finish_reason: str  # \"tool_calls\" | \"stop\" | ...\n    has_tool_call: bool\n    tool_call: dict[str, str] | None  # {\"id\": ..., \"name\": ..., \"arguments\": ...}\n    content: str | None\n\n\n@dataclass\nclass ScenarioResult:\n    name: str\n    api: str\n    phase: str  # \"tool_call\" or \"follow_up\"\n    passed: bool\n    checks: dict[str, bool] = field(default_factory=dict)\n    error: str | None = None\n    latency_ms: float = 0.0\n\n\ndef validate_args(args_str: str, required_keys: list[str]) -> tuple[bool, str | None]:\n    \"\"\"Parse JSON arguments and check required keys exist.\"\"\"\n    try:\n        args = json.loads(args_str)\n    except (json.JSONDecodeError, TypeError) as exc:\n        return False, f\"Invalid JSON: {exc}\"\n    if not isinstance(args, dict):\n        return False, f\"Expected dict, got {type(args).__name__}\"\n    missing = [k for k in required_keys if k not in args]\n    if missing:\n        return False, f\"Missing keys: {missing}\"\n    return True, None\n\n\ndef validate_nested_args(\n    args_str: str,\n    array_key: str,\n    required_item_keys: list[str],\n) -> tuple[bool, str | None]:\n    \"\"\"Check that args[array_key] is a list of objects with required keys.\"\"\"\n    try:\n        args = json.loads(args_str)\n    except (json.JSONDecodeError, TypeError) as exc:\n        return False, f\"Invalid JSON: {exc}\"\n    if not isinstance(args, dict):\n        return False, f\"Expected dict, got {type(args).__name__}\"\n    arr = args.get(array_key)\n    if not isinstance(arr, list):\n        return False, f\"'{array_key}' is not an array (got {type(arr).__name__})\"\n    if len(arr) == 0:\n        return False, f\"'{array_key}' is empty\"\n    for i, item in enumerate(arr):\n        if not isinstance(item, dict):\n            return (\n                False,\n                f\"'{array_key}[{i}]' is not an object (got {type(item).__name__})\",\n            )\n        missing = [k for k in required_item_keys if k not in item]\n        if missing:\n            return False, f\"'{array_key}[{i}]' missing keys: {missing}\"\n    return True, None\n\n\ndef call_api(\n    client: httpx.Client,\n    host: str,\n    port: int,\n    path: str,\n    body: dict[str, Any],\n    timeout: float,\n) -> tuple[dict[str, Any], float]:\n    \"\"\"POST to http://{host}:{port}{path}, return (response_json, latency_ms).\"\"\"\n    url = f\"http://{host}:{port}{path}\"\n    t0 = time.monotonic()\n    resp = client.post(url, json=body, timeout=timeout)\n    latency = (time.monotonic() - t0) * 1000\n    resp.raise_for_status()\n    return resp.json(), latency\n\n\ndef _openai_build_request(\n    model: str,\n    messages: list[dict[str, Any]],\n    tools: list[dict[str, Any]],\n) -> tuple[str, dict[str, Any]]:\n    \"\"\"Build request for /v1/chat/completions.\"\"\"\n    body: dict[str, Any] = {\n        \"model\": model,\n        \"messages\": messages,\n        \"tools\": tools,\n        \"max_tokens\": 16384,\n        \"temperature\": 0.0,\n    }\n    return \"/v1/chat/completions\", body\n\n\ndef _openai_parse_response(data: dict[str, Any]) -> ParsedResponse:\n    \"\"\"Parse OpenAI Chat 
Completions response into common format.\"\"\"\n    choice = data[\"choices\"][0]\n    finish_reason = choice.get(\"finish_reason\", \"\")\n    message = choice.get(\"message\", {})\n    tool_calls = message.get(\"tool_calls\")\n    content = message.get(\"content\")\n\n    has_tool_call = isinstance(tool_calls, list) and len(tool_calls) > 0\n    tool_call_info: dict[str, str] | None = None\n    if has_tool_call:\n        tc = tool_calls[0]\n        fn = tc.get(\"function\", {})\n        tool_call_info = {\n            \"id\": tc.get(\"id\", \"call_0\"),\n            \"name\": fn.get(\"name\", \"\"),\n            \"arguments\": fn.get(\"arguments\", \"{}\"),\n        }\n\n    return ParsedResponse(\n        finish_reason=finish_reason,\n        has_tool_call=has_tool_call,\n        tool_call=tool_call_info,\n        content=content,\n    )\n\n\ndef _openai_build_followup(\n    messages: list[dict[str, Any]],\n    tools: list[dict[str, Any]],\n    model: str,\n    parsed: ParsedResponse,\n    tool_result: str,\n) -> tuple[str, dict[str, Any]]:\n    \"\"\"Build multi-turn follow-up for OpenAI Chat Completions.\"\"\"\n    assert parsed.tool_call is not None\n    tc = parsed.tool_call\n    followup_messages: list[dict[str, Any]] = list(messages) + [\n        {\n            \"role\": \"assistant\",\n            \"tool_calls\": [\n                {\n                    \"id\": tc[\"id\"],\n                    \"type\": \"function\",\n                    \"function\": {\n                        \"name\": tc[\"name\"],\n                        \"arguments\": tc[\"arguments\"],\n                    },\n                }\n            ],\n        },\n        {\n            \"role\": \"tool\",\n            \"tool_call_id\": tc[\"id\"],\n            \"content\": tool_result,\n        },\n    ]\n    body: dict[str, Any] = {\n        \"model\": model,\n        \"messages\": followup_messages,\n        \"tools\": tools,\n        \"max_tokens\": 16384,\n        \"temperature\": 0.0,\n    }\n    return \"/v1/chat/completions\", body\n\n\ndef _claude_translate_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:\n    \"\"\"Translate OpenAI-format tools to Claude format.\"\"\"\n    claude_tools: list[dict[str, Any]] = []\n    for tool in tools:\n        fn = tool[\"function\"]\n        claude_tools.append(\n            {\n                \"name\": fn[\"name\"],\n                \"description\": fn.get(\"description\", \"\"),\n                \"input_schema\": fn.get(\"parameters\", {}),\n            }\n        )\n    return claude_tools\n\n\ndef _claude_translate_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:\n    \"\"\"Translate OpenAI-format messages to Claude Messages format.\"\"\"\n    claude_messages: list[dict[str, Any]] = []\n\n    for msg in messages:\n        role = msg[\"role\"]\n\n        if role == \"user\":\n            claude_messages.append(\n                {\n                    \"role\": \"user\",\n                    \"content\": msg[\"content\"],\n                }\n            )\n        elif role == \"assistant\":\n            content_blocks: list[dict[str, Any]] = []\n            text_content = msg.get(\"content\")\n            if text_content and isinstance(text_content, str) and text_content.strip():\n                content_blocks.append({\"type\": \"text\", \"text\": text_content})\n            tool_calls = msg.get(\"tool_calls\")\n            if tool_calls:\n                for tc in tool_calls:\n                    fn = tc.get(\"function\", {})\n          
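          # arguments arrive as a JSON string; malformed output degrades to {}\n          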
          args_str = fn.get(\"arguments\", \"{}\")\n                    try:\n                        args_dict = json.loads(args_str)\n                    except (json.JSONDecodeError, TypeError):\n                        args_dict = {}\n                    content_blocks.append(\n                        {\n                            \"type\": \"tool_use\",\n                            \"id\": tc.get(\"id\", \"call_0\"),\n                            \"name\": fn.get(\"name\", \"\"),\n                            \"input\": args_dict,\n                        }\n                    )\n            if not content_blocks:\n                content_blocks.append({\"type\": \"text\", \"text\": \"\"})\n            claude_messages.append(\n                {\n                    \"role\": \"assistant\",\n                    \"content\": content_blocks,\n                }\n            )\n        elif role == \"tool\":\n            claude_messages.append(\n                {\n                    \"role\": \"user\",\n                    \"content\": [\n                        {\n                            \"type\": \"tool_result\",\n                            \"tool_use_id\": msg.get(\"tool_call_id\", \"call_0\"),\n                            \"content\": msg.get(\"content\", \"\"),\n                        }\n                    ],\n                }\n            )\n        elif role == \"system\":\n            pass\n\n    return claude_messages\n\n\ndef _claude_build_request(\n    model: str,\n    messages: list[dict[str, Any]],\n    tools: list[dict[str, Any]],\n) -> tuple[str, dict[str, Any]]:\n    \"\"\"Build request for /v1/messages.\"\"\"\n    claude_messages = _claude_translate_messages(messages)\n    claude_tools = _claude_translate_tools(tools)\n\n    system_content: str | None = None\n    for msg in messages:\n        if msg[\"role\"] == \"system\":\n            system_content = msg[\"content\"]\n            break\n\n    body: dict[str, Any] = {\n        \"model\": model,\n        \"messages\": claude_messages,\n        \"tools\": claude_tools,\n        \"max_tokens\": 16384,\n        \"temperature\": 0.0,\n    }\n    if system_content is not None:\n        body[\"system\"] = system_content\n\n    return \"/v1/messages\", body\n\n\ndef _claude_parse_response(data: dict[str, Any]) -> ParsedResponse:\n    \"\"\"Parse Claude Messages response into common format.\"\"\"\n    stop_reason = data.get(\"stop_reason\", \"\")\n    content_blocks = data.get(\"content\", [])\n\n    if stop_reason == \"tool_use\":\n        finish_reason = \"tool_calls\"\n    elif stop_reason == \"end_turn\":\n        finish_reason = \"stop\"\n    else:\n        finish_reason = stop_reason\n\n    tool_call_info: dict[str, str] | None = None\n    text_parts: list[str] = []\n    has_tool_call = False\n\n    for block in content_blocks:\n        block_type = block.get(\"type\")\n        if block_type == \"tool_use\":\n            has_tool_call = True\n            if tool_call_info is None:\n                input_data = block.get(\"input\", {})\n                tool_call_info = {\n                    \"id\": block.get(\"id\", \"call_0\"),\n                    \"name\": block.get(\"name\", \"\"),\n                    \"arguments\": json.dumps(input_data)\n                    if isinstance(input_data, dict)\n                    else str(input_data),\n                }\n        elif block_type == \"text\":\n            text = block.get(\"text\", \"\")\n            if text.strip():\n                text_parts.append(text)\n\n    
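# a turn may interleave narration and tool_use blocks; keep the text as content\n    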
content = \"\\n\".join(text_parts) if text_parts else None\n\n    return ParsedResponse(\n        finish_reason=finish_reason,\n        has_tool_call=has_tool_call,\n        tool_call=tool_call_info,\n        content=content,\n    )\n\n\ndef _claude_build_followup(\n    messages: list[dict[str, Any]],\n    tools: list[dict[str, Any]],\n    model: str,\n    parsed: ParsedResponse,\n    tool_result: str,\n) -> tuple[str, dict[str, Any]]:\n    \"\"\"Build multi-turn follow-up for Claude Messages.\"\"\"\n    assert parsed.tool_call is not None\n    tc = parsed.tool_call\n\n    try:\n        args_dict = json.loads(tc[\"arguments\"])\n    except (json.JSONDecodeError, TypeError):\n        args_dict = {}\n\n    claude_messages = _claude_translate_messages(messages)\n\n    claude_messages.append(\n        {\n            \"role\": \"assistant\",\n            \"content\": [\n                {\n                    \"type\": \"tool_use\",\n                    \"id\": tc[\"id\"],\n                    \"name\": tc[\"name\"],\n                    \"input\": args_dict,\n                }\n            ],\n        }\n    )\n\n    claude_messages.append(\n        {\n            \"role\": \"user\",\n            \"content\": [\n                {\n                    \"type\": \"tool_result\",\n                    \"tool_use_id\": tc[\"id\"],\n                    \"content\": tool_result,\n                }\n            ],\n        }\n    )\n\n    claude_tools = _claude_translate_tools(tools)\n\n    system_content: str | None = None\n    for msg in messages:\n        if msg[\"role\"] == \"system\":\n            system_content = msg[\"content\"]\n            break\n\n    body: dict[str, Any] = {\n        \"model\": model,\n        \"messages\": claude_messages,\n        \"tools\": claude_tools,\n        \"max_tokens\": 16384,\n        \"temperature\": 0.0,\n    }\n    if system_content is not None:\n        body[\"system\"] = system_content\n\n    return \"/v1/messages\", body\n\n\ndef _responses_translate_input(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:\n    \"\"\"Translate OpenAI chat messages to Responses API input items.\"\"\"\n    items: list[dict[str, Any]] = []\n\n    for msg in messages:\n        role = msg[\"role\"]\n\n        if role in (\"user\", \"system\"):\n            items.append(\n                {\n                    \"type\": \"message\",\n                    \"role\": role,\n                    \"content\": msg[\"content\"],\n                }\n            )\n        elif role == \"assistant\":\n            text_content = msg.get(\"content\")\n            if text_content and isinstance(text_content, str) and text_content.strip():\n                items.append(\n                    {\n                        \"type\": \"message\",\n                        \"role\": \"assistant\",\n                        \"content\": text_content,\n                    }\n                )\n            tool_calls = msg.get(\"tool_calls\")\n            if tool_calls:\n                for tc in tool_calls:\n                    fn = tc.get(\"function\", {})\n                    items.append(\n                        {\n                            \"type\": \"function_call\",\n                            \"call_id\": tc.get(\"id\", \"call_0\"),\n                            \"name\": fn.get(\"name\", \"\"),\n                            \"arguments\": fn.get(\"arguments\", \"{}\"),\n                        }\n                    )\n        elif role == \"tool\":\n            items.append(\n             
   {\n                    \"type\": \"function_call_output\",\n                    \"call_id\": msg.get(\"tool_call_id\", \"call_0\"),\n                    \"output\": msg.get(\"content\", \"\"),\n                }\n            )\n\n    return items\n\n\ndef _responses_build_request(\n    model: str,\n    messages: list[dict[str, Any]],\n    tools: list[dict[str, Any]],\n) -> tuple[str, dict[str, Any]]:\n    \"\"\"Build request for /v1/responses.\"\"\"\n    input_items = _responses_translate_input(messages)\n\n    body: dict[str, Any] = {\n        \"model\": model,\n        \"input\": input_items,\n        \"tools\": tools,\n        \"temperature\": 0.0,\n        \"max_output_tokens\": 4096,\n    }\n    return \"/v1/responses\", body\n\n\ndef _responses_parse_response(data: dict[str, Any]) -> ParsedResponse:\n    \"\"\"Parse OpenAI Responses API response into common format.\"\"\"\n    output = data.get(\"output\", [])\n\n    tool_call_info: dict[str, str] | None = None\n    text_parts: list[str] = []\n    has_tool_call = False\n\n    for item in output:\n        item_type = item.get(\"type\")\n        if item_type == \"function_call\":\n            has_tool_call = True\n            if tool_call_info is None:\n                tool_call_info = {\n                    \"id\": item.get(\"call_id\", \"call_0\"),\n                    \"name\": item.get(\"name\", \"\"),\n                    \"arguments\": item.get(\"arguments\", \"{}\"),\n                }\n        elif item_type == \"message\":\n            msg_content = item.get(\"content\", [])\n            if isinstance(msg_content, list):\n                for block in msg_content:\n                    if isinstance(block, dict):\n                        text = block.get(\"text\", \"\")\n                        if text and text.strip():\n                            text_parts.append(text)\n            elif isinstance(msg_content, str) and msg_content.strip():\n                text_parts.append(msg_content)\n\n    content = \"\\n\".join(text_parts) if text_parts else None\n\n    if has_tool_call:\n        finish_reason = \"tool_calls\"\n    else:\n        status = data.get(\"status\", \"completed\")\n        finish_reason = \"stop\" if status == \"completed\" else status\n\n    return ParsedResponse(\n        finish_reason=finish_reason,\n        has_tool_call=has_tool_call,\n        tool_call=tool_call_info,\n        content=content,\n    )\n\n\ndef _responses_build_followup(\n    messages: list[dict[str, Any]],\n    tools: list[dict[str, Any]],\n    model: str,\n    parsed: ParsedResponse,\n    tool_result: str,\n) -> tuple[str, dict[str, Any]]:\n    \"\"\"Build multi-turn follow-up for Responses API.\"\"\"\n    assert parsed.tool_call is not None\n    tc = parsed.tool_call\n\n    input_items = _responses_translate_input(messages)\n\n    input_items.append(\n        {\n            \"type\": \"function_call\",\n            \"call_id\": tc[\"id\"],\n            \"name\": tc[\"name\"],\n            \"arguments\": tc[\"arguments\"],\n        }\n    )\n\n    input_items.append(\n        {\n            \"type\": \"function_call_output\",\n            \"call_id\": tc[\"id\"],\n            \"output\": tool_result,\n        }\n    )\n\n    body: dict[str, Any] = {\n        \"model\": model,\n        \"input\": input_items,\n        \"tools\": tools,\n        \"temperature\": 0.0,\n        \"max_output_tokens\": 4096,\n    }\n    return \"/v1/responses\", body\n\n\nADAPTERS: dict[ApiName, dict[str, Any]] = {\n    \"openai\": {\n        
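# every adapter exposes the same three hooks: build_request, parse_response, build_followup\n        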
\"build_request\": _openai_build_request,\n        \"parse_response\": _openai_parse_response,\n        \"build_followup\": _openai_build_followup,\n    },\n    \"claude\": {\n        \"build_request\": _claude_build_request,\n        \"parse_response\": _claude_parse_response,\n        \"build_followup\": _claude_build_followup,\n    },\n    \"responses\": {\n        \"build_request\": _responses_build_request,\n        \"parse_response\": _responses_parse_response,\n        \"build_followup\": _responses_build_followup,\n    },\n}\n\n\ndef run_scenario(\n    client: httpx.Client,\n    host: str,\n    port: int,\n    model: str,\n    scenario: Scenario,\n    api_name: ApiName,\n    timeout: float,\n    verbose: bool,\n) -> list[ScenarioResult]:\n    \"\"\"Run a single scenario against one API adapter. Returns 1-2 results.\"\"\"\n    adapter = ADAPTERS[api_name]\n    build_request = adapter[\"build_request\"]\n    parse_response = adapter[\"parse_response\"]\n    build_followup = adapter[\"build_followup\"]\n    results: list[ScenarioResult] = []\n\n    # --- Phase 1: initial request ---\n    path, body = build_request(model, scenario.messages, scenario.tools)\n\n    if verbose:\n        print(\n            f\"    [{api_name}] request: {path} {json.dumps(body, indent=2)}\",\n            file=sys.stderr,\n        )\n\n    try:\n        data, latency = call_api(client, host, port, path, body, timeout)\n    except Exception as exc:\n        results.append(\n            ScenarioResult(\n                name=scenario.name,\n                api=api_name,\n                phase=\"tool_call\",\n                passed=False,\n                error=f\"API error: {exc}\",\n            )\n        )\n        return results\n\n    if verbose:\n        print(\n            f\"    [{api_name}] response: {json.dumps(data, indent=2)}\", file=sys.stderr\n        )\n\n    parsed = parse_response(data)\n    checks: dict[str, bool] = {}\n\n    if scenario.expect_tool_call:\n        checks[\"finish_reason_tool_calls\"] = parsed.finish_reason == \"tool_calls\"\n        checks[\"has_tool_call\"] = parsed.has_tool_call\n\n        args_err: str | None = None\n        if parsed.has_tool_call and parsed.tool_call is not None:\n            checks[\"correct_function\"] = (\n                scenario.expected_function is None\n                or parsed.tool_call[\"name\"] == scenario.expected_function\n            )\n            if scenario.required_arg_keys:\n                ok, args_err = validate_args(\n                    parsed.tool_call[\"arguments\"], scenario.required_arg_keys\n                )\n                checks[\"valid_arguments\"] = ok\n            else:\n                checks[\"valid_arguments\"] = True\n            if scenario.nested_array_key and scenario.required_item_keys:\n                ok, nested_err = validate_nested_args(\n                    parsed.tool_call[\"arguments\"],\n                    scenario.nested_array_key,\n                    scenario.required_item_keys,\n                )\n                checks[\"valid_nested_structure\"] = ok\n                if not ok:\n                    args_err = nested_err\n        else:\n            checks[\"correct_function\"] = False\n            checks[\"valid_arguments\"] = False\n            args_err = \"No tool call returned\"\n\n        passed = all(checks.values())\n        error = args_err if not passed else None\n    else:\n        checks[\"finish_reason_stop\"] = parsed.finish_reason == \"stop\"\n        checks[\"no_tool_call\"] = not 
parsed.has_tool_call\n        checks[\"has_content\"] = (\n            parsed.content is not None and len(parsed.content.strip()) > 0\n        )\n        passed = all(checks.values())\n        error = (\n            None\n            if passed\n            else (\n                f\"finish_reason={parsed.finish_reason}, \"\n                f\"tool_call={'yes' if parsed.has_tool_call else 'no'}, \"\n                f\"content={'yes' if parsed.content else 'no'}\"\n            )\n        )\n\n    results.append(\n        ScenarioResult(\n            name=scenario.name,\n            api=api_name,\n            phase=\"tool_call\",\n            passed=passed,\n            checks=checks,\n            error=error,\n            latency_ms=latency,\n        )\n    )\n\n    # --- Phase 2: multi-turn follow-up ---\n    if (\n        scenario.tool_result is not None\n        and parsed.has_tool_call\n        and parsed.tool_call is not None\n    ):\n        followup_path, followup_body = build_followup(\n            scenario.messages,\n            scenario.tools,\n            model,\n            parsed,\n            scenario.tool_result,\n        )\n\n        if verbose:\n            print(\n                f\"    [{api_name}] follow_up request: {followup_path} {json.dumps(followup_body, indent=2)}\",\n                file=sys.stderr,\n            )\n\n        try:\n            data2, latency2 = call_api(\n                client, host, port, followup_path, followup_body, timeout\n            )\n        except Exception as exc:\n            results.append(\n                ScenarioResult(\n                    name=scenario.name,\n                    api=api_name,\n                    phase=\"follow_up\",\n                    passed=False,\n                    error=f\"API error: {exc}\",\n                )\n            )\n            return results\n\n        if verbose:\n            print(\n                f\"    [{api_name}] follow_up response: {json.dumps(data2, indent=2)}\",\n                file=sys.stderr,\n            )\n\n        parsed2 = parse_response(data2)\n        checks2: dict[str, bool] = {}\n        checks2[\"finish_reason_stop\"] = parsed2.finish_reason == \"stop\"\n        checks2[\"no_tool_call\"] = not parsed2.has_tool_call\n        checks2[\"has_content\"] = (\n            parsed2.content is not None and len(parsed2.content.strip()) > 0\n        )\n\n        passed2 = all(checks2.values())\n        error2: str | None = None\n        if not passed2:\n            error2 = (\n                f\"finish_reason={parsed2.finish_reason}, \"\n                f\"tool_call={'yes' if parsed2.has_tool_call else 'no'}, \"\n                f\"content={'yes' if parsed2.content else 'no'}\"\n            )\n        results.append(\n            ScenarioResult(\n                name=scenario.name,\n                api=api_name,\n                phase=\"follow_up\",\n                passed=passed2,\n                checks=checks2,\n                error=error2,\n                latency_ms=latency2,\n            )\n        )\n\n    return results\n\n\ndef result_to_dict(result: ScenarioResult) -> dict[str, Any]:\n    \"\"\"Convert a ScenarioResult to a JSON-serializable dict.\"\"\"\n    return {\n        \"name\": result.name,\n        \"api\": result.api,\n        \"phase\": result.phase,\n        \"passed\": result.passed,\n        \"checks\": result.checks,\n        \"error\": result.error,\n        \"latency_ms\": round(result.latency_ms, 1),\n    }\n\n\n_MULTI_NODE_PRIORITY: dict[tuple[str, str], 
int] = {\n    (\"tensor\", \"jaccl\"): 0,\n    (\"pipeline\", \"jaccl\"): 2,\n    (\"pipeline\", \"ring\"): 3,\n    (\"tensor\", \"ring\"): 4,\n}\n_SINGLE_NODE_PRIORITY = 1\n\n\ndef _placement_sort_key(p: dict[str, Any]) -> tuple[int, int]:\n    sharding = p.get(\"sharding\", \"\").lower()\n    meta = p.get(\"instance_meta\", \"\").lower()\n    kind = (\n        \"tensor\" if \"tensor\" in sharding else \"pipeline\",\n        \"jaccl\" if \"jaccl\" in meta else \"ring\",\n    )\n    n_nodes = nodes_used_in_instance(p[\"instance\"])\n    if n_nodes == 1:\n        return (_SINGLE_NODE_PRIORITY, -n_nodes)\n    priority = _MULTI_NODE_PRIORITY.get(kind, 99)\n    return (priority, -n_nodes)\n\n\ndef main() -> None:\n    parser = argparse.ArgumentParser(\n        description=\"Multi-API tool-calling eval for exo\",\n        formatter_class=argparse.RawDescriptionHelpFormatter,\n        epilog=\"\"\"\\\nExamples:\n  %(prog)s --model mlx-community/Qwen3-30B-A3B-4bit\n  %(prog)s --model my-model --api openai --repeat 3\n  %(prog)s --model my-model --api all --scenarios weather_simple calculator_multi_turn\n  %(prog)s --model my-model --stdout\n\"\"\",\n    )\n    add_common_instance_args(parser)\n    parser.add_argument(\n        \"--api\",\n        choices=[\"openai\", \"claude\", \"responses\", \"all\"],\n        default=\"all\",\n        help=\"Which API adapter(s) to test (default: all)\",\n    )\n    parser.add_argument(\n        \"--repeat\",\n        type=int,\n        default=1,\n        help=\"Repeat each scenario N times (default: 1)\",\n    )\n    parser.add_argument(\n        \"--scenarios\",\n        nargs=\"*\",\n        help=\"Run only these scenarios (by name)\",\n    )\n    parser.add_argument(\n        \"--verbose\",\n        action=\"store_true\",\n        help=\"Print full API responses to stderr\",\n    )\n    parser.add_argument(\n        \"--json-out\",\n        default=\"bench/eval_results.json\",\n        help=\"Write JSON results to file (default: bench/eval_results.json)\",\n    )\n    parser.add_argument(\n        \"--stdout\",\n        action=\"store_true\",\n        help=\"Write JSON results to stdout instead of file\",\n    )\n    args = parser.parse_args()\n\n    all_scenarios = load_scenarios(SCENARIOS_PATH)\n    if args.scenarios:\n        scenarios = [s for s in all_scenarios if s.name in args.scenarios]\n        if not scenarios:\n            print(\n                f\"No matching scenarios. 
Available: {[s.name for s in all_scenarios]}\",\n                file=sys.stderr,\n            )\n            sys.exit(1)\n    else:\n        scenarios = all_scenarios\n\n    api_names: list[ApiName] = (\n        [\"openai\", \"claude\", \"responses\"] if args.api == \"all\" else [args.api]\n    )\n\n    log = sys.stderr if args.stdout else sys.stdout\n    exo = ExoClient(args.host, args.port, timeout_s=args.timeout)\n    _short_id, full_model_id = resolve_model_short_id(exo, args.model)\n\n    selected = settle_and_fetch_placements(\n        exo, full_model_id, args, settle_timeout=args.settle_timeout\n    )\n    if not selected:\n        print(\"No valid placements matched your filters.\", file=sys.stderr)\n        sys.exit(1)\n\n    selected.sort(key=_placement_sort_key)\n    preview = selected[0]\n\n    settle_deadline = (\n        time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None\n    )\n\n    print(\"Planning phase: checking downloads...\", file=log)\n    run_planning_phase(\n        exo,\n        full_model_id,\n        preview,\n        args.danger_delete_downloads,\n        args.timeout,\n        settle_deadline,\n    )\n\n    instance = preview[\"instance\"]\n    instance_id = instance_id_from_instance(instance)\n    sharding = str(preview[\"sharding\"])\n    instance_meta = str(preview[\"instance_meta\"])\n    n_nodes = nodes_used_in_instance(instance)\n\n    print(f\"Model:     {full_model_id}\", file=log)\n    print(f\"Placement: {sharding} / {instance_meta} / {n_nodes} nodes\", file=log)\n    print(f\"Endpoint:  http://{args.host}:{args.port}\", file=log)\n    print(f\"APIs:      {', '.join(api_names)}\", file=log)\n\n    total_runs = len(scenarios) * args.repeat * len(api_names)\n    print(\n        f\"Scenarios: {len(scenarios)} x {args.repeat} repeats x {len(api_names)} APIs = {total_runs} runs\",\n        file=log,\n    )\n    print(\"=\" * 72, file=log)\n\n    exo.request_json(\"POST\", \"/instance\", body={\"instance\": instance})\n    try:\n        wait_for_instance_ready(exo, instance_id)\n    except (RuntimeError, TimeoutError) as e:\n        print(f\"Failed to initialize placement: {e}\", file=sys.stderr)\n        with contextlib.suppress(ExoHttpError):\n            exo.request_json(\"DELETE\", f\"/instance/{instance_id}\")\n        sys.exit(1)\n\n    time.sleep(1)\n    all_results: list[ScenarioResult] = []\n\n    try:\n        with httpx.Client() as http_client:\n            for run_idx in range(args.repeat):\n                if args.repeat > 1:\n                    print(f\"\\n--- Run {run_idx + 1}/{args.repeat} ---\", file=log)\n\n                for scenario in scenarios:\n                    for api_name in api_names:\n                        print(\n                            f\"\\n  [{api_name:>9}] {scenario.name}: {scenario.description}\",\n                            file=log,\n                        )\n\n                        scenario_results = run_scenario(\n                            http_client,\n                            args.host,\n                            args.port,\n                            full_model_id,\n                            scenario,\n                            api_name,\n                            args.timeout,\n                            args.verbose,\n                        )\n                        all_results.extend(scenario_results)\n\n                        for r in scenario_results:\n                            status = \"PASS\" if r.passed else \"FAIL\"\n                            
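# one verdict line per phase, then a +/- mark per individual check\n                            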
print(\n                                f\"    [{r.phase:>10}] {status}  ({r.latency_ms:.0f}ms)\",\n                                file=log,\n                            )\n                            for check_name, check_ok in r.checks.items():\n                                mark = \"+\" if check_ok else \"-\"\n                                print(f\"      {mark} {check_name}\", file=log)\n                            if r.error:\n                                print(f\"      ! {r.error}\", file=log)\n    finally:\n        try:\n            exo.request_json(\"DELETE\", f\"/instance/{instance_id}\")\n        except ExoHttpError as e:\n            if e.status != 404:\n                raise\n        wait_for_instance_gone(exo, instance_id)\n\n    # --- Summary ---\n    print(f\"\\n{'=' * 72}\", file=log)\n\n    total = len(all_results)\n    passed = sum(1 for r in all_results if r.passed)\n\n    tool_call_results = [r for r in all_results if r.phase == \"tool_call\"]\n    follow_up_results = [r for r in all_results if r.phase == \"follow_up\"]\n    tc_passed = sum(1 for r in tool_call_results if r.passed)\n    fu_passed = sum(1 for r in follow_up_results if r.passed)\n    avg_latency = sum(r.latency_ms for r in all_results) / total if total else 0\n\n    print(\n        f\"Total:       {passed}/{total} passed ({100 * passed / total:.0f}%)\", file=log\n    )\n    print(f\"Tool call:   {tc_passed}/{len(tool_call_results)} passed\", file=log)\n    if follow_up_results:\n        print(f\"Follow-up:   {fu_passed}/{len(follow_up_results)} passed\", file=log)\n    print(f\"Avg latency: {avg_latency:.0f}ms\", file=log)\n\n    for api_name in api_names:\n        api_results = [r for r in all_results if r.api == api_name]\n        api_passed = sum(1 for r in api_results if r.passed)\n        print(f\"  {api_name:>9}: {api_passed}/{len(api_results)} passed\", file=log)\n\n    if passed < total:\n        print(\"\\nFailed:\", file=log)\n        for r in all_results:\n            if not r.passed:\n                print(f\"  - {r.name} [{r.api}/{r.phase}]: {r.error}\", file=log)\n\n    json_results = [result_to_dict(r) for r in all_results]\n\n    if args.stdout:\n        print(json.dumps(json_results, indent=2))\n    else:\n        json_path = args.json_out\n        parent = os.path.dirname(json_path)\n        if parent:\n            os.makedirs(parent, exist_ok=True)\n        with open(json_path, \"w\") as f:\n            json.dump(json_results, f, indent=2)\n            f.write(\"\\n\")\n        print(f\"\\nJSON results written to {json_path}\", file=log)\n\n    sys.exit(0 if passed == total else 1)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "bench/exo_bench.py",
    "content": "# type: ignore\n#!/usr/bin/env python3\n\"\"\"Tool-calling eval for exo's OpenAI-compatible API.\n\nTests whether models correctly:\n- Trigger tool calls when appropriate\n- Return valid JSON arguments matching function schemas\n- Handle multi-turn tool use (call -> result -> final answer)\n- Avoid calling tools when unnecessary\n\nStart exo with a model first, then run:\n    uv run python tool_call_eval.py --model <model-id>\n    uv run python tool_call_eval.py --model <model-id> --host 10.0.0.5 --port 52415\n    uv run python tool_call_eval.py --model <model-id> --repeat 3\n    uv run python tool_call_eval.py --model <model-id> --scenarios weather_simple calculator_multi_turn\n\"\"\"\n\nfrom __future__ import annotations\n\nimport argparse\nimport contextlib\nimport itertools\nimport json\nimport sys\nimport time\nfrom collections.abc import Callable\nfrom concurrent.futures import ThreadPoolExecutor, as_completed\nfrom pathlib import Path\nfrom statistics import mean\nfrom typing import Any\n\nfrom harness import (\n    ExoClient,\n    ExoHttpError,\n    add_common_instance_args,\n    instance_id_from_instance,\n    nodes_used_in_instance,\n    resolve_model_short_id,\n    run_planning_phase,\n    settle_and_fetch_placements,\n    wait_for_instance_gone,\n    wait_for_instance_ready,\n)\nfrom loguru import logger\nfrom transformers import AutoTokenizer\n\n# Monkey-patch for transformers 5.x compatibility\n# Kimi's tokenization_kimi.py imports bytes_to_unicode from the old location\n# which was moved in transformers 5.0.0rc2\ntry:\n    import transformers.models.gpt2.tokenization_gpt2 as gpt2_tokenization\n    from transformers.convert_slow_tokenizer import bytes_to_unicode\n\n    if not hasattr(gpt2_tokenization, \"bytes_to_unicode\"):\n        gpt2_tokenization.bytes_to_unicode = bytes_to_unicode  # type: ignore[attr-defined]\nexcept ImportError:\n    pass  # transformers < 5.0 or bytes_to_unicode not available\n\n\ndef load_tokenizer_for_bench(model_id: str) -> Any:\n    \"\"\"\n    Load tokenizer for benchmarking, with special handling for Kimi models.\n\n    Kimi uses a custom TikTokenTokenizer that transformers 5.x can't load via AutoTokenizer.\n    This function replicates the logic from utils_mlx.py for bench compatibility.\n    \"\"\"\n    model_id_lower = model_id.lower()\n\n    if \"kimi-k2\" in model_id_lower:\n        import importlib.util\n        import types\n\n        from huggingface_hub import snapshot_download\n\n        # Download/get the model path\n        model_path = Path(\n            snapshot_download(\n                model_id,\n                allow_patterns=[\"*.json\", \"*.py\", \"*.tiktoken\", \"*.model\"],\n            )\n        )\n\n        sys.path.insert(0, str(model_path))\n\n        # Load tool_declaration_ts first (tokenization_kimi imports it with relative import)\n        tool_decl_path = model_path / \"tool_declaration_ts.py\"\n        if tool_decl_path.exists():\n            spec = importlib.util.spec_from_file_location(\n                \"tool_declaration_ts\", tool_decl_path\n            )\n            if spec and spec.loader:\n                tool_decl_module = importlib.util.module_from_spec(spec)\n                sys.modules[\"tool_declaration_ts\"] = tool_decl_module\n                spec.loader.exec_module(tool_decl_module)\n\n        # Load tokenization_kimi with patched source (convert relative to absolute import)\n        tok_path = model_path / \"tokenization_kimi.py\"\n        source = tok_path.read_text()\n        
source = source.replace(\"from .tool_declaration_ts\", \"from tool_declaration_ts\")\n        spec = importlib.util.spec_from_file_location(\"tokenization_kimi\", tok_path)\n        if spec:\n            tok_module = types.ModuleType(\"tokenization_kimi\")\n            tok_module.__file__ = str(tok_path)\n            sys.modules[\"tokenization_kimi\"] = tok_module\n            exec(compile(source, tok_path, \"exec\"), tok_module.__dict__)  # noqa: S102\n            TikTokenTokenizer = tok_module.TikTokenTokenizer  # noqa: N806\n        else:\n            from tokenization_kimi import TikTokenTokenizer  # type: ignore[import-not-found]  # noqa: I001\n\n        hf_tokenizer: Any = TikTokenTokenizer.from_pretrained(model_path)\n\n        # Patch encode to use internal tiktoken model directly\n        # transformers 5.x has a bug in the encode->pad path for slow tokenizers\n        def _patched_encode(text: str, **kwargs: object) -> list[int]:\n            # Pass allowed_special=\"all\" to handle special tokens like <|im_user|>\n            return list(hf_tokenizer.model.encode(text, allowed_special=\"all\"))\n\n        hf_tokenizer.encode = _patched_encode\n\n        return hf_tokenizer\n\n    # Default: use AutoTokenizer\n    return AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)\n\n\ndef format_peak_memory(b: float) -> str:\n    for unit in [\"B\", \"KB\", \"MB\", \"GB\", \"TB\"]:\n        if b < 1024.0:\n            return f\"{b:.2f}{unit}\"\n        b /= 1024.0\n    raise ValueError(\"You're using petabytes of memory. Something went wrong...\")\n\n\ndef parse_int_list(values: list[str]) -> list[int]:\n    items: list[int] = []\n    for v in values:\n        for part in v.split(\",\"):\n            part = part.strip()\n            if part:\n                items.append(int(part))\n    return items\n\n\ndef run_one_completion(\n    client: ExoClient, model_id: str, pp_hint: int, tg: int, prompt_sizer: PromptSizer\n) -> tuple[dict[str, Any], int]:\n    content, pp_tokens = prompt_sizer.build(pp_hint)\n    payload: dict[str, Any] = {\n        \"model\": model_id,\n        \"messages\": [{\"role\": \"user\", \"content\": content}],\n        \"stream\": False,\n        \"max_tokens\": tg,\n    }\n\n    t0 = time.perf_counter()\n    out = client.post_bench_chat_completions(payload)\n    elapsed = time.perf_counter() - t0\n\n    stats = out.get(\"generation_stats\")\n\n    # Extract preview, handling None content (common for thinking models)\n    choices = out.get(\"choices\") or [{}]\n    message = choices[0].get(\"message\", {}) if choices else {}\n    content = message.get(\"content\") or \"\"\n    preview = content[:200] if content else \"\"\n\n    return {\n        \"elapsed_s\": elapsed,\n        \"output_text_preview\": preview,\n        \"stats\": stats,\n    }, pp_tokens\n\n\nclass PromptSizer:\n    def __init__(self, tokenizer: Any, atom: str = \"a \"):\n        self.tokenizer = tokenizer\n        self.atom = atom\n        self.count_fn = PromptSizer._make_counter(tokenizer)\n        self.base_tokens = self.count_fn(\"\")\n\n    @staticmethod\n    def _make_counter(tokenizer: Any) -> Callable[[str], int]:\n        def count_fn(user_content: str) -> int:\n            messages = [{\"role\": \"user\", \"content\": user_content}]\n            ids = tokenizer.apply_chat_template(\n                messages, tokenize=True, add_generation_prompt=True\n            )\n            # Fix for transformers 5.x\n            if hasattr(ids, \"input_ids\"):\n                ids = 
ids.input_ids\n            return int(len(ids))\n\n        return count_fn\n\n    def build(self, target_prompt_tokens: int) -> tuple[str, int]:\n        target = int(target_prompt_tokens)\n        if target < self.base_tokens:\n            raise RuntimeError(\n                f\"Target ({target}) is smaller than template overhead ({self.base_tokens}).\"\n            )\n\n        # Estimate tokens per atom using a sample\n        sample_count = 100\n        sample_content = self.atom * sample_count\n        sample_tokens = self.count_fn(sample_content) - self.base_tokens\n        tokens_per_atom = sample_tokens / sample_count\n\n        # Estimate starting point\n        needed_tokens = target - self.base_tokens\n        estimated_atoms = int(needed_tokens / tokens_per_atom)\n\n        # Binary search to find exact atom count\n        low, high = 0, estimated_atoms * 2 + 100\n        while low < high:\n            mid = (low + high) // 2\n            tok = self.count_fn(self.atom * mid)\n            if tok < target:\n                low = mid + 1\n            else:\n                high = mid\n\n        content = self.atom * low\n        tok = self.count_fn(content)\n        logger.info(f\"{tok=}\")\n\n        if tok != target:\n            raise RuntimeError(\n                f\"Overshot: got {tok} tokens (target {target}). \"\n                f\"Pick a different atom (try ' a' or '\\\\n' or '0 ').\"\n            )\n\n        return content, tok\n\n\ndef main() -> int:\n    ap = argparse.ArgumentParser(\n        prog=\"exo-bench\",\n        description=\"Benchmark exo model throughput across placement previews.\",\n    )\n    add_common_instance_args(ap)\n    ap.add_argument(\n        \"--pp\",\n        nargs=\"+\",\n        required=True,\n        help=\"Prompt-size hints (ints). Accepts commas.\",\n    )\n    ap.add_argument(\n        \"--tg\",\n        nargs=\"+\",\n        required=True,\n        help=\"Generation lengths (ints). Accepts commas.\",\n    )\n    ap.add_argument(\n        \"--repeat\", type=int, default=1, help=\"Repetitions per (pp,tg) pair.\"\n    )\n    ap.add_argument(\n        \"--concurrency\",\n        nargs=\"+\",\n        default=[\"1\"],\n        help=\"Concurrency levels (ints). Accepts commas. E.g. --concurrency 1,2,4,8. 
Default 1.\",\n    )\n    ap.add_argument(\n        \"--warmup\",\n        type=int,\n        default=0,\n        help=\"Warmup runs per placement (uses first pp/tg).\",\n    )\n    ap.add_argument(\n        \"--json-out\",\n        default=\"bench/results.json\",\n        help=\"Write raw per-run results JSON to this path.\",\n    )\n    ap.add_argument(\"--stdout\", action=\"store_true\", help=\"Write results to stdout\")\n    ap.add_argument(\n        \"--dry-run\", action=\"store_true\", help=\"List selected placements and exit.\"\n    )\n    ap.add_argument(\n        \"--all-combinations\",\n        action=\"store_true\",\n        help=\"Force all pp×tg combinations (cartesian product) even when lists have equal length.\",\n    )\n    args = ap.parse_args()\n\n    pp_list = parse_int_list(args.pp)\n    tg_list = parse_int_list(args.tg)\n    if not pp_list or not tg_list:\n        logger.error(\"pp and tg lists must be non-empty\")\n        return 2\n    if args.repeat <= 0:\n        logger.error(\"--repeat must be >= 1\")\n        return 2\n    concurrency_list = parse_int_list(args.concurrency)\n    if not concurrency_list or any(c <= 0 for c in concurrency_list):\n        logger.error(\"--concurrency values must be >= 1\")\n        return 2\n\n    # Log pairing mode\n    use_combinations = args.all_combinations or len(pp_list) != len(tg_list)\n    if use_combinations:\n        logger.info(\n            f\"pp/tg mode: combinations (product) - {len(pp_list) * len(tg_list)} pairs\"\n        )\n    else:\n        logger.info(f\"pp/tg mode: tandem (zip) - {len(pp_list)} pairs\")\n\n    client = ExoClient(args.host, args.port, timeout_s=args.timeout)\n    short_id, full_model_id = resolve_model_short_id(\n        client, args.model, force_download=args.force_download\n    )\n\n    tokenizer = load_tokenizer_for_bench(full_model_id)\n    if tokenizer is None:\n        raise RuntimeError(\"[exo-bench] tokenizer load failed\")\n\n    try:\n        prompt_sizer = PromptSizer(tokenizer)\n        logger.debug(f\"[exo-bench] loaded tokenizer: {full_model_id} for prompt sizer\")\n    except Exception:\n        logger.error(\"[exo-bench] tokenizer usable but prompt sizing failed\")\n        raise\n\n    selected = settle_and_fetch_placements(\n        client, full_model_id, args, settle_timeout=args.settle_timeout\n    )\n\n    if not selected:\n        logger.error(\"No valid placements matched your filters.\")\n        return 1\n\n    selected.sort(\n        key=lambda p: (\n            str(p.get(\"instance_meta\", \"\")),\n            str(p.get(\"sharding\", \"\")),\n            -nodes_used_in_instance(p[\"instance\"]),\n        ),\n        reverse=True,\n    )\n\n    logger.debug(f\"exo-bench model: short_id={short_id} full_id={full_model_id}\")\n    logger.info(f\"placements: {len(selected)}\")\n    for p in selected:\n        logger.info(\n            f\"  - {p['sharding']} / {p['instance_meta']} / nodes={nodes_used_in_instance(p['instance'])}\"\n        )\n\n    if args.dry_run:\n        return 0\n\n    settle_deadline = (\n        time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None\n    )\n\n    logger.info(\"Planning phase: checking downloads...\")\n    download_duration_s = run_planning_phase(\n        client,\n        full_model_id,\n        selected[0],\n        args.danger_delete_downloads,\n        args.timeout,\n        settle_deadline,\n    )\n    if download_duration_s is not None:\n        logger.info(f\"Download: {download_duration_s:.1f}s (freshly 
downloaded)\")\n    else:\n        logger.info(\"Download: model already cached\")\n\n    all_rows: list[dict[str, Any]] = []\n\n    for preview in selected:\n        instance = preview[\"instance\"]\n        instance_id = instance_id_from_instance(instance)\n\n        sharding = str(preview[\"sharding\"])\n        instance_meta = str(preview[\"instance_meta\"])\n        n_nodes = nodes_used_in_instance(instance)\n\n        logger.info(\"=\" * 80)\n        logger.info(\n            f\"PLACEMENT: {sharding} / {instance_meta} / nodes={n_nodes} / instance_id={instance_id}\"\n        )\n\n        client.request_json(\"POST\", \"/instance\", body={\"instance\": instance})\n        try:\n            wait_for_instance_ready(client, instance_id)\n        except (RuntimeError, TimeoutError) as e:\n            logger.error(f\"Failed to initialize placement: {e}\")\n            with contextlib.suppress(ExoHttpError):\n                client.request_json(\"DELETE\", f\"/instance/{instance_id}\")\n            continue\n\n        time.sleep(1)\n\n        try:\n            for i in range(args.warmup):\n                run_one_completion(\n                    client, full_model_id, pp_list[0], tg_list[0], prompt_sizer\n                )\n                logger.debug(f\"  warmup {i + 1}/{args.warmup} done\")\n\n            # If pp and tg lists have same length, run in tandem (zip)\n            # Otherwise (or if --all-combinations), run all combinations (cartesian product)\n            if use_combinations:\n                pp_tg_pairs = list(itertools.product(pp_list, tg_list))\n            else:\n                pp_tg_pairs = list(zip(pp_list, tg_list, strict=True))\n\n            for pp, tg in pp_tg_pairs:\n                for concurrency in concurrency_list:\n                    logger.info(f\"--- pp={pp} tg={tg} concurrency={concurrency} ---\")\n                    runs: list[dict[str, Any]] = []\n                    for r in range(args.repeat):\n                        time.sleep(3)\n\n                        if concurrency <= 1:\n                            # Sequential: single request\n                            try:\n                                row, actual_pp_tokens = run_one_completion(\n                                    client, full_model_id, pp, tg, prompt_sizer\n                                )\n                            except Exception as e:\n                                logger.error(e)\n                                continue\n                            row.update(\n                                {\n                                    \"model_short_id\": short_id,\n                                    \"model_id\": full_model_id,\n                                    \"placement_sharding\": sharding,\n                                    \"placement_instance_meta\": instance_meta,\n                                    \"placement_nodes\": n_nodes,\n                                    \"instance_id\": instance_id,\n                                    \"pp_tokens\": actual_pp_tokens,\n                                    \"tg\": tg,\n                                    \"repeat_index\": r,\n                                    \"concurrency\": 1,\n                                    **(\n                                        {\"download_duration_s\": download_duration_s}\n                                        if download_duration_s is not None\n                                        else {}\n                                    ),\n                                }\n              
              )\n                            runs.append(row)\n                            all_rows.append(row)\n                        else:\n                            # Concurrent: fire N requests in parallel\n                            # Each thread gets its own ExoClient (separate HTTP connection)\n                            batch_results: list[tuple[dict[str, Any], int]] = []\n                            batch_errors = 0\n\n                            def _run_concurrent(\n                                idx: int, *, _pp: int = pp, _tg: int = tg\n                            ) -> tuple[dict[str, Any], int]:\n                                c = ExoClient(\n                                    args.host, args.port, timeout_s=args.timeout\n                                )\n                                return run_one_completion(\n                                    c, full_model_id, _pp, _tg, prompt_sizer\n                                )\n\n                            with ThreadPoolExecutor(max_workers=concurrency) as pool:\n                                futures = {\n                                    pool.submit(_run_concurrent, i): i\n                                    for i in range(concurrency)\n                                }\n                                for fut in as_completed(futures):\n                                    try:\n                                        batch_results.append(fut.result())\n                                    except Exception as e:\n                                        logger.error(f\"Concurrent request failed: {e}\")\n                                        batch_errors += 1\n\n                            for idx, (row, actual_pp_tokens) in enumerate(\n                                batch_results\n                            ):\n                                row.update(\n                                    {\n                                        \"model_short_id\": short_id,\n                                        \"model_id\": full_model_id,\n                                        \"placement_sharding\": sharding,\n                                        \"placement_instance_meta\": instance_meta,\n                                        \"placement_nodes\": n_nodes,\n                                        \"instance_id\": instance_id,\n                                        \"pp_tokens\": actual_pp_tokens,\n                                        \"tg\": tg,\n                                        \"repeat_index\": r,\n                                        \"concurrency\": concurrency,\n                                        \"concurrent_index\": idx,\n                                        **(\n                                            {\"download_duration_s\": download_duration_s}\n                                            if download_duration_s is not None\n                                            else {}\n                                        ),\n                                    }\n                                )\n                                runs.append(row)\n                                all_rows.append(row)\n\n                            if batch_results:\n                                valid_gen_tps = [\n                                    x[\"stats\"][\"generation_tps\"]\n                                    for x, _ in batch_results\n                                    if x[\"stats\"][\"generation_tps\"] > 0\n                                ]\n                                
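# agg_gen_tps averages the per-request generation_tps values; dividing by\n                                # the concurrency level gives each request's share under load.\n                                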
agg_gen_tps = (\n                                    mean(valid_gen_tps) if valid_gen_tps else 0.0\n                                )\n                                gen_tps = agg_gen_tps / concurrency\n                                logger.info(\n                                    f\"[concurrent {concurrency}x]  \"\n                                    f\"agg_gen_tps={agg_gen_tps:.2f}  \"\n                                    f\"gen_tps={gen_tps:.2f}  \"\n                                    f\"errors={batch_errors}\"\n                                )\n\n                    if runs:\n                        prompt_tps = mean(x[\"stats\"][\"prompt_tps\"] for x in runs)\n                        gen_tps = mean(\n                            x[\"stats\"][\"generation_tps\"] / x[\"concurrency\"]\n                            for x in runs\n                        )\n                        ptok = mean(x[\"stats\"][\"prompt_tokens\"] for x in runs)\n                        gtok = mean(x[\"stats\"][\"generation_tokens\"] for x in runs)\n                        peak = mean(\n                            x[\"stats\"][\"peak_memory_usage\"][\"inBytes\"] for x in runs\n                        )\n\n                        logger.info(\n                            f\"prompt_tps={prompt_tps:.2f} gen_tps={gen_tps:.2f}    \"\n                            f\"prompt_tokens={ptok} gen_tokens={gtok}    \"\n                            f\"peak_memory={format_peak_memory(peak)}\\n\"\n                        )\n                    time.sleep(2)\n        finally:\n            try:\n                client.request_json(\"DELETE\", f\"/instance/{instance_id}\")\n            except ExoHttpError as e:\n                if e.status != 404:\n                    raise\n            wait_for_instance_gone(client, instance_id)\n            logger.debug(f\"Deleted instance {instance_id}\")\n\n            time.sleep(5)\n\n    if args.stdout:\n        json.dump(all_rows, sys.stdout, indent=2, ensure_ascii=False)\n    elif args.json_out:\n        with open(args.json_out, \"w\", encoding=\"utf-8\") as f:\n            json.dump(all_rows, f, indent=2, ensure_ascii=False)\n        logger.debug(f\"\\nWrote results JSON: {args.json_out}\")\n\n    return 0\n\n\nif __name__ == \"__main__\":\n    raise SystemExit(main())\n"
  },
  {
    "path": "bench/exo_eval.py",
    "content": "# type: ignore\n#!/usr/bin/env python3\n\"\"\"Quality evaluation for exo — matches Artificial Analysis methodology.\n\nRuns LLM benchmarks against exo's OpenAI-compatible API using the same\nprompts, temperature settings, and answer extraction as Artificial Analysis.\n\nSupported benchmarks:\n  gpqa_diamond   - Graduate-level science QA (198 questions, 4-choice MC)\n  mmlu_pro       - Multi-task language understanding (12K questions, 10-choice MC)\n  aime_2024      - Math olympiad 2024 (30 problems, integer answers)\n  aime_2025      - Math olympiad 2025 (30 problems, integer answers)\n  humaneval      - Python code generation (164 problems, pass@1)\n  livecodebench  - Competitive programming (880+ problems, pass@1)\n\nModel configs in eval_configs/models.toml auto-detect reasoning/non-reasoning\nsettings per model. Override with --reasoning / --no-reasoning.\n\nUsage:\n  uv run python exo_eval.py --model <model-id> --tasks gpqa_diamond\n  uv run python exo_eval.py --model <model-id> --tasks humaneval,livecodebench --limit 50\n  uv run python exo_eval.py --model <model-id> --tasks gpqa_diamond --compare-concurrency 1,4\n\nReferences:\n  https://artificialanalysis.ai/methodology/intelligence-benchmarking\n\"\"\"\n\nfrom __future__ import annotations\n\nimport argparse\nimport asyncio\nimport contextlib\nimport json\nimport multiprocessing\nimport random\nimport re\nimport sys\nimport time\nimport tomllib\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import Any\n\nimport httpx\nfrom harness import (\n    ExoClient,\n    ExoHttpError,\n    add_common_instance_args,\n    instance_id_from_instance,\n    nodes_used_in_instance,\n    resolve_model_short_id,\n    run_planning_phase,\n    settle_and_fetch_placements,\n    wait_for_instance_gone,\n    wait_for_instance_ready,\n)\nfrom loguru import logger\n\n# ---------------------------------------------------------------------------\n# Artificial Analysis constants\n# ---------------------------------------------------------------------------\n\nMAX_RETRIES = 30\nDEFAULT_MAX_TOKENS = 16_384\nREASONING_MAX_TOKENS = 131_072\nTEMPERATURE_NON_REASONING = 0.0\nTEMPERATURE_REASONING = 1.0\n\n# MC answer extraction: 8 fallback regex patterns.\n# All patterns are tried; the match at the latest text position wins\n# (handles models that self-correct during reasoning).\n_MC_PATTERNS: list[re.Pattern[str]] = [\n    re.compile(\n        r\"(?i)[\\*\\_]{0,2}Answer[\\*\\_]{0,2}\\s*:[\\s\\*\\_]{0,2}\\s*([A-Z])(?![a-zA-Z0-9])\"\n    ),\n    re.compile(r\"\\\\boxed\\{[^}]*([A-Z])[^}]*\\}\"),\n    re.compile(r\"(?i)answer is ([a-zA-Z])\"),\n    re.compile(r\"(?i)answer is \\\\\\(([a-zA-Z])\"),\n    re.compile(r\"([A-Z])\\)\\s*[^A-Z]*$\"),\n    re.compile(r\"([A-Z])\\s+is\\s+the\\s+correct\\s+answer\"),\n    re.compile(r\"([A-Z])\\s*$\"),\n    re.compile(r\"([A-Z])\\s*\\.\"),\n]\n\n# Code extraction: last ```python ... ``` block (AA regex)\n_CODE_BLOCK_RE = re.compile(r\"```(?:python|Python)?\\s*\\n(.*?)```\", re.DOTALL)\n\n\n# ---------------------------------------------------------------------------\n# Model config loading\n# ---------------------------------------------------------------------------\n\n\ndef load_model_config(model_id: str) -> dict[str, Any] | None:\n    \"\"\"Look up model in eval_configs/models.toml. 
Returns config dict or None.\"\"\"\n    config_path = Path(__file__).resolve().parent / \"eval_configs\" / \"models.toml\"\n    if not config_path.exists():\n        return None\n    with open(config_path, \"rb\") as f:\n        data = tomllib.load(f)\n    for entry in data.get(\"model\", []):\n        patterns = entry.get(\"patterns\", [])\n        if any(p in model_id for p in patterns):\n            return entry\n    return None\n\n\n# ---------------------------------------------------------------------------\n# Answer extraction\n# ---------------------------------------------------------------------------\n\n\ndef extract_mc_answer(text: str, valid_letters: str = \"ABCD\") -> str | None:\n    \"\"\"Extract MC answer. Last match by text position wins.\"\"\"\n    valid_set = set(valid_letters)\n    best: tuple[int, str] | None = None\n    for pattern in _MC_PATTERNS:\n        for m in pattern.finditer(text):\n            letter = m.group(1).upper()\n            if letter in valid_set:\n                pos = m.start()\n                if best is None or pos >= best[0]:\n                    best = (pos, letter)\n    return best[1] if best else None\n\n\ndef extract_boxed_answer(text: str) -> str | None:\n    r\"\"\"Extract content from the last \\boxed{...}.\"\"\"\n    matches: list[str] = []\n    idx = 0\n    while True:\n        pos = text.find(\"\\\\boxed{\", idx)\n        if pos < 0:\n            break\n        depth = 0\n        i = pos + len(\"\\\\boxed{\")\n        start = i\n        while i < len(text):\n            if text[i] == \"{\":\n                depth += 1\n            elif text[i] == \"}\":\n                if depth == 0:\n                    matches.append(text[start:i])\n                    break\n                depth -= 1\n            i += 1\n        idx = i + 1 if i < len(text) else len(text)\n    return matches[-1].strip() if matches else None\n\n\ndef extract_code_block(text: str, preserve_indent: bool = False) -> str | None:\n    \"\"\"Extract the last Python code block from markdown response.\n\n    If preserve_indent is True, only strip trailing whitespace (keeps leading\n    indentation intact — needed for HumanEval function-body completions).\n    \"\"\"\n    matches = _CODE_BLOCK_RE.findall(text)\n    if matches:\n        raw = matches[-1]\n        return raw.rstrip() if preserve_indent else raw.strip()\n    # Fallback: try raw code after last ```\n    lines = text.split(\"\\n\")\n    backtick_lines = [i for i, line in enumerate(lines) if \"```\" in line]\n    if len(backtick_lines) >= 2:\n        return \"\\n\".join(lines[backtick_lines[-2] + 1 : backtick_lines[-1]])\n    return None\n\n\ndef check_aime_answer(extracted: str, gold: int) -> bool:\n    \"\"\"Check if extracted AIME answer matches gold integer.\"\"\"\n    try:\n        return int(extracted.strip()) == gold\n    except ValueError:\n        pass\n    try:\n        from math_verify import parse, verify\n\n        return verify(parse(str(gold)), parse(extracted))\n    except Exception:\n        return False\n\n\n# ---------------------------------------------------------------------------\n# Code execution — official evaluation harnesses\n# ---------------------------------------------------------------------------\n\n# LiveCodeBench: vendored from https://github.com/LiveCodeBench/LiveCodeBench\n# run_test() must execute in a child process because reliability_guard()\n# permanently disables OS functions (os.kill, subprocess.Popen, etc.).\n\n\ndef _lcb_worker(\n    sample: dict,\n    code: str,\n   
 timeout: int,\n    result_holder: list[Any],\n    metadata_holder: list[Any],\n) -> None:\n    \"\"\"Target for multiprocessing.Process — runs vendored LCB run_test.\"\"\"\n    from vendor.lcb_testing_util import run_test\n\n    try:\n        results, metadata = run_test(sample, test=code, debug=False, timeout=timeout)\n        result_holder.append(results)\n        metadata_holder.append(metadata)\n    except Exception as e:\n        result_holder.append([-4])\n        metadata_holder.append({\"error_code\": -4, \"error_message\": str(e)})\n\n\ndef run_livecodebench_test(\n    code: str,\n    sample: dict,\n    timeout: int = 6,\n) -> tuple[bool, str]:\n    \"\"\"Run LCB evaluation in a subprocess. Returns (passed, diagnostic_info).\"\"\"\n    manager = multiprocessing.Manager()\n    result_holder = manager.list()\n    metadata_holder = manager.list()\n\n    proc = multiprocessing.Process(\n        target=_lcb_worker,\n        args=(sample, code, timeout, result_holder, metadata_holder),\n    )\n    proc.start()\n\n    # Global timeout: (per-test timeout + 1) * num_tests + 5\n    num_tests = len(json.loads(sample[\"input_output\"]).get(\"inputs\", []))\n    global_timeout = (timeout + 1) * num_tests + 5\n    proc.join(timeout=global_timeout)\n\n    if proc.is_alive():\n        proc.kill()\n        proc.join()\n        return False, \"Global timeout exceeded\"\n\n    if not result_holder:\n        return False, \"No results returned from worker\"\n\n    results = list(result_holder[0])\n    metadata = dict(metadata_holder[0]) if metadata_holder else {}\n\n    # LCB convention: True = pass, negative int = failure code\n    all_passed = all(r is True or r == 1 for r in results)\n    if all_passed:\n        return True, \"\"\n\n    diag = metadata.get(\"error_message\", \"\")\n    if not diag and \"output\" in metadata:\n        diag = f\"Got {metadata['output']}, expected {metadata.get('expected', '?')}\"\n    return False, diag\n\n\ndef run_humaneval_test(\n    problem: dict, completion: str, timeout: float = 10.0\n) -> tuple[bool, str]:\n    \"\"\"Run HumanEval evaluation using the official human_eval package.\"\"\"\n    from human_eval.execution import check_correctness\n\n    result = check_correctness(problem, completion, timeout)\n    passed = result[\"passed\"]\n    diag = \"\" if passed else result.get(\"result\", \"failed\")\n    return passed, diag\n\n\n# ---------------------------------------------------------------------------\n# Benchmark definitions\n# ---------------------------------------------------------------------------\n\n\n@dataclass\nclass QuestionResult:\n    question_id: int\n    prompt: str\n    response: str\n    extracted_answer: str | None\n    gold_answer: str\n    correct: bool\n    error: str | None = None\n    prompt_tokens: int = 0\n    completion_tokens: int = 0\n    reasoning_tokens: int = 0\n    elapsed_s: float = 0.0\n\n\n@dataclass\nclass BenchmarkConfig:\n    name: str\n    description: str\n    dataset_name: str\n    dataset_config: str | None\n    split: str\n    kind: str  # \"mc\", \"math\", \"code\"\n\n\nBENCHMARKS: dict[str, BenchmarkConfig] = {\n    \"gpqa_diamond\": BenchmarkConfig(\n        name=\"gpqa_diamond\",\n        description=\"Graduate-level science QA (198 Q, 4-choice MC)\",\n        dataset_name=\"Idavidrein/gpqa\",\n        dataset_config=\"gpqa_diamond\",\n        split=\"train\",\n        kind=\"mc\",\n    ),\n    \"mmlu_pro\": BenchmarkConfig(\n        name=\"mmlu_pro\",\n        description=\"Multi-task language 
understanding (12K Q, 10-choice MC)\",\n        dataset_name=\"TIGER-Lab/MMLU-Pro\",\n        dataset_config=None,\n        split=\"test\",\n        kind=\"mc\",\n    ),\n    \"aime_2024\": BenchmarkConfig(\n        name=\"aime_2024\",\n        description=\"Math olympiad 2024 (30 problems, integer answers)\",\n        dataset_name=\"HuggingFaceH4/aime_2024\",\n        dataset_config=None,\n        split=\"train\",\n        kind=\"math\",\n    ),\n    \"aime_2025\": BenchmarkConfig(\n        name=\"aime_2025\",\n        description=\"Math olympiad 2025 (30 problems, integer answers)\",\n        dataset_name=\"MathArena/aime_2025\",\n        dataset_config=None,\n        split=\"train\",\n        kind=\"math\",\n    ),\n    \"humaneval\": BenchmarkConfig(\n        name=\"humaneval\",\n        description=\"Python code generation (164 problems, pass@1)\",\n        dataset_name=\"openai/openai_humaneval\",\n        dataset_config=None,\n        split=\"test\",\n        kind=\"code\",\n    ),\n    \"livecodebench\": BenchmarkConfig(\n        name=\"livecodebench\",\n        description=\"Competitive programming (880+ problems, pass@1)\",\n        dataset_name=\"livecodebench/code_generation_lite\",\n        dataset_config=None,\n        split=\"test\",\n        kind=\"code\",\n    ),\n}\n\n\n# ---------------------------------------------------------------------------\n# Prompt formatters\n# ---------------------------------------------------------------------------\n\n_GPQA_INSTRUCTION = (\n    \"Answer the following multiple choice question. \"\n    \"The last line of your response should be in the following format: \"\n    \"'Answer: A/B/C/D' (e.g. 'Answer: A').\"\n)\n\n_MMLU_PRO_INSTRUCTION = (\n    \"Answer the following multiple choice question. \"\n    \"The last line of your response should be in the following format: \"\n    \"'Answer: A/B/C/D/E/F/G/H/I/J' (e.g. 'Answer: A').\"\n)\n\n_AIME_INSTRUCTION = (\n    \"Solve the following math problem step by step. \"\n    \"Put your answer inside \\\\boxed{}.\\n\"\n    \"Remember to put your answer inside \\\\boxed{}.\"\n)\n\n_HUMANEVAL_INSTRUCTION = (\n    \"Complete the following Python function. Return only the function body \"\n    \"inside a ```python code block. Do not include the function signature.\"\n)\n\n# LiveCodeBench: AA uses original prompts without custom system prompts\n_LCB_SYSTEM = (\n    \"You are an expert Python programmer. You will be given a question \"\n    \"(problem specification) and will generate a correct Python program \"\n    \"that matches the specification and passes all tests.\"\n)\n\n_LCB_WITH_STARTER = (\n    \"### Question:\\n{question}\\n\\n\"\n    \"### Format: You will use the following starter code to write the \"\n    \"solution to the problem and enclose your code within delimiters.\\n\"\n    \"```python\\n{starter_code}\\n```\\n\\n\"\n    \"### Answer: (use the provided format with backticks)\\n\"\n)\n\n_LCB_WITHOUT_STARTER = (\n    \"### Question:\\n{question}\\n\\n\"\n    \"### Format: Read the inputs from stdin solve the problem and write \"\n    \"the answer to stdout (do not directly test on the sample inputs). \"\n    \"Enclose your code within delimiters as follows. 
Ensure that when the \"\n    \"python program runs, it reads the inputs, runs the algorithm and \"\n    \"writes output to STDOUT.\\n\"\n    \"```python\\n# YOUR CODE HERE\\n```\\n\\n\"\n    \"### Answer: (use the provided format with backticks)\\n\"\n)\n\n\ndef format_gpqa_question(doc: dict, idx: int) -> tuple[str, str]:\n    \"\"\"Returns (prompt, correct_letter).\"\"\"\n    correct = doc[\"Correct Answer\"]\n    choices = [\n        correct,\n        doc[\"Incorrect Answer 1\"],\n        doc[\"Incorrect Answer 2\"],\n        doc[\"Incorrect Answer 3\"],\n    ]\n    rng = random.Random(idx)\n    order = rng.sample(range(4), 4)\n    shuffled = [choices[i] for i in order]\n    correct_letter = \"ABCD\"[order.index(0)]\n    choices_text = \"\\n\".join(f\"{L}) {shuffled[i]}\" for i, L in enumerate(\"ABCD\"))\n    return f\"{_GPQA_INSTRUCTION}\\n\\n{doc['Question']}\\n\\n{choices_text}\", correct_letter\n\n\ndef format_mmlu_pro_question(doc: dict) -> tuple[str, str]:\n    \"\"\"Returns (prompt, correct_letter).\"\"\"\n    options = doc[\"options\"]\n    letters = \"ABCDEFGHIJ\"\n    choices_text = \"\\n\".join(f\"{letters[i]}) {opt}\" for i, opt in enumerate(options))\n    return f\"{_MMLU_PRO_INSTRUCTION}\\n\\n{doc['question']}\\n\\n{choices_text}\", doc[\n        \"answer\"\n    ]\n\n\ndef format_aime_question(doc: dict) -> tuple[str, int]:\n    \"\"\"Returns (prompt, correct_answer_int).\"\"\"\n    return f\"{_AIME_INSTRUCTION}\\n\\n{doc['problem']}\", int(doc[\"answer\"])\n\n\ndef format_humaneval_question(doc: dict) -> tuple[str, dict]:\n    \"\"\"Returns (prompt, metadata_for_execution).\"\"\"\n    prompt = f\"{_HUMANEVAL_INSTRUCTION}\\n\\n```python\\n{doc['prompt']}```\"\n    # Pass the full problem dict — check_correctness needs task_id, prompt,\n    # test, entry_point\n    meta = {\n        \"problem\": {\n            \"task_id\": doc[\"task_id\"],\n            \"prompt\": doc[\"prompt\"],\n            \"test\": doc[\"test\"],\n            \"entry_point\": doc[\"entry_point\"],\n        },\n    }\n    return prompt, meta\n\n\ndef format_livecodebench_question(doc: dict) -> tuple[str, str | None, dict]:\n    \"\"\"Returns (prompt, system_message, metadata_for_execution).\"\"\"\n    starter_code = doc.get(\"starter_code\", \"\")\n    question_content = doc[\"question_content\"]\n\n    if starter_code and starter_code.strip():\n        user_msg = _LCB_WITH_STARTER.format(\n            question=question_content, starter_code=starter_code\n        )\n    else:\n        user_msg = _LCB_WITHOUT_STARTER.format(question=question_content)\n\n    # Parse test cases\n    public_tests = (\n        json.loads(doc[\"public_test_cases\"])\n        if isinstance(doc[\"public_test_cases\"], str)\n        else doc[\"public_test_cases\"]\n    )\n    private_tests = doc.get(\"private_test_cases\", \"[]\")\n    if isinstance(private_tests, str):\n        try:\n            private_tests = json.loads(private_tests)\n        except Exception:\n            import base64\n            import pickle\n            import zlib\n\n            private_tests = json.loads(\n                pickle.loads(\n                    zlib.decompress(base64.b64decode(private_tests.encode(\"utf-8\")))\n                )\n            )\n\n    all_tests = public_tests + (\n        private_tests if isinstance(private_tests, list) else []\n    )\n    test_inputs = [t[\"input\"] for t in all_tests]\n    test_outputs = [t[\"output\"] for t in all_tests]\n\n    metadata = doc.get(\"metadata\", \"{}\")\n    if isinstance(metadata, 
str):\n        metadata = json.loads(metadata)\n    func_name = metadata.get(\"func_name\")\n\n    # Build the sample dict in official LCB format for run_test()\n    input_output: dict[str, Any] = {\n        \"inputs\": test_inputs,\n        \"outputs\": test_outputs,\n    }\n    if func_name:\n        input_output[\"fn_name\"] = func_name\n\n    meta = {\n        \"sample\": {\"input_output\": json.dumps(input_output)},\n    }\n    return user_msg, _LCB_SYSTEM, meta\n\n\n# ---------------------------------------------------------------------------\n# API client with retries\n# ---------------------------------------------------------------------------\n\n\n@dataclass\nclass ApiResult:\n    content: str\n    prompt_tokens: int\n    completion_tokens: int\n    reasoning_tokens: int\n\n\nasync def _call_api(\n    client: httpx.AsyncClient,\n    base_url: str,\n    model: str,\n    prompt: str,\n    temperature: float,\n    max_tokens: int,\n    timeout: float | None,\n    system_message: str | None = None,\n    reasoning_effort: str | None = None,\n    top_p: float | None = None,\n) -> ApiResult:\n    messages = []\n    if system_message:\n        messages.append({\"role\": \"system\", \"content\": system_message})\n    messages.append({\"role\": \"user\", \"content\": prompt})\n\n    body: dict[str, Any] = {\n        \"model\": model,\n        \"messages\": messages,\n        \"temperature\": temperature,\n        \"max_tokens\": max_tokens,\n    }\n    if reasoning_effort is not None:\n        body[\"reasoning_effort\"] = reasoning_effort\n    if top_p is not None:\n        body[\"top_p\"] = top_p\n\n    resp = await client.post(\n        f\"{base_url}/v1/chat/completions\",\n        json=body,\n        timeout=timeout,\n    )\n    resp.raise_for_status()\n    data = resp.json()\n    content = data[\"choices\"][0][\"message\"][\"content\"]\n    if not content or not content.strip():\n        raise ValueError(\"Empty response from model\")\n    usage = data.get(\"usage\", {})\n    details = usage.get(\"completion_tokens_details\", {})\n    return ApiResult(\n        content=content,\n        prompt_tokens=usage.get(\"prompt_tokens\", 0),\n        completion_tokens=usage.get(\"completion_tokens\", 0),\n        reasoning_tokens=details.get(\"reasoning_tokens\", 0) if details else 0,\n    )\n\n\nasync def call_with_retries(\n    client: httpx.AsyncClient,\n    base_url: str,\n    model: str,\n    prompt: str,\n    temperature: float,\n    max_tokens: int,\n    timeout: float | None = None,\n    system_message: str | None = None,\n    reasoning_effort: str | None = None,\n    top_p: float | None = None,\n) -> ApiResult | None:\n    for attempt in range(MAX_RETRIES):\n        try:\n            return await _call_api(\n                client,\n                base_url,\n                model,\n                prompt,\n                temperature,\n                max_tokens,\n                timeout,\n                system_message,\n                reasoning_effort,\n                top_p,\n            )\n        except Exception as e:\n            if attempt < MAX_RETRIES - 1:\n                wait = min(2**attempt, 60)\n                logger.warning(\n                    f\"Attempt {attempt + 1}/{MAX_RETRIES} failed: {e}. Retrying in {wait}s...\"\n                )\n                await asyncio.sleep(wait)\n            else:\n                logger.error(f\"All {MAX_RETRIES} retries exhausted. 
Last error: {e}\")\n                return None\n\n\n# ---------------------------------------------------------------------------\n# Evaluation runners\n# ---------------------------------------------------------------------------\n\n\nasync def evaluate_benchmark(\n    benchmark_name: str,\n    base_url: str,\n    model: str,\n    temperature: float,\n    max_tokens: int,\n    concurrency: int = 1,\n    limit: int | None = None,\n    timeout: float | None = None,\n    reasoning_effort: str | None = None,\n    top_p: float | None = None,\n    difficulty: str | None = None,\n) -> list[QuestionResult]:\n    \"\"\"Run a benchmark. Returns per-question results.\"\"\"\n    import datasets\n\n    config = BENCHMARKS[benchmark_name]\n    logger.info(f\"Loading dataset {config.dataset_name}...\")\n\n    try:\n        if benchmark_name == \"livecodebench\":\n            ds = datasets.load_dataset(\n                \"json\",\n                data_files=\"hf://datasets/livecodebench/code_generation_lite/*.jsonl\",\n                split=\"train\",\n            )\n        else:\n            ds = datasets.load_dataset(\n                config.dataset_name,\n                config.dataset_config,\n                split=config.split,\n            )\n    except Exception as e:\n        logger.error(f\"Failed to load dataset: {e}\")\n        if \"gated\" in str(e).lower() or \"login\" in str(e).lower():\n            logger.error(\"Dataset requires authentication. Run: huggingface-cli login\")\n        return []\n\n    if difficulty and \"difficulty\" in ds.column_names:\n        ds = ds.filter(lambda x: x[\"difficulty\"] == difficulty)\n        logger.info(f\"Filtered to {len(ds)} {difficulty} problems\")\n\n    total = len(ds)\n    if limit and limit < total:\n        ds = ds.select(range(limit))\n        total = limit\n\n    logger.info(\n        f\"Evaluating {benchmark_name}: {total} questions, concurrency={concurrency}, \"\n        f\"temperature={temperature}, max_tokens={max_tokens}\"\n    )\n\n    if config.kind == \"code\":\n        logger.warning(\n            \"Code benchmarks execute model-generated code. 
Use a sandboxed environment.\"\n        )\n\n    semaphore = asyncio.Semaphore(concurrency)\n    results: list[QuestionResult | None] = [None] * total\n    completed = 0\n    lock = asyncio.Lock()\n\n    async def process_question(\n        idx: int, doc: dict, http_client: httpx.AsyncClient\n    ) -> None:\n        nonlocal completed\n        system_msg = None\n\n        if benchmark_name == \"gpqa_diamond\":\n            prompt, gold = format_gpqa_question(doc, idx)\n            valid_letters = \"ABCD\"\n        elif benchmark_name == \"mmlu_pro\":\n            prompt, gold = format_mmlu_pro_question(doc)\n            valid_letters = \"ABCDEFGHIJ\"[: len(doc[\"options\"])]\n        elif benchmark_name.startswith(\"aime\"):\n            prompt, gold_int = format_aime_question(doc)\n            gold = str(gold_int)\n        elif benchmark_name == \"humaneval\":\n            prompt, exec_meta = format_humaneval_question(doc)\n            gold = \"pass\"\n        elif benchmark_name == \"livecodebench\":\n            prompt, system_msg, exec_meta = format_livecodebench_question(doc)\n            gold = \"pass\"\n        else:\n            raise ValueError(f\"Unknown benchmark: {benchmark_name}\")\n\n        async with semaphore:\n            t0 = time.monotonic()\n            api_result = await call_with_retries(\n                http_client,\n                base_url,\n                model,\n                prompt,\n                temperature,\n                max_tokens,\n                timeout,\n                system_message=system_msg,\n                reasoning_effort=reasoning_effort,\n                top_p=top_p,\n            )\n            elapsed = time.monotonic() - t0\n\n        if api_result is None:\n            result = QuestionResult(\n                question_id=idx,\n                prompt=prompt,\n                response=\"\",\n                extracted_answer=None,\n                gold_answer=gold,\n                correct=False,\n                error=\"API failure after retries\",\n                elapsed_s=elapsed,\n            )\n        else:\n            response = api_result.content\n            stats = {\n                \"prompt_tokens\": api_result.prompt_tokens,\n                \"completion_tokens\": api_result.completion_tokens,\n                \"reasoning_tokens\": api_result.reasoning_tokens,\n                \"elapsed_s\": elapsed,\n            }\n\n            if config.kind == \"mc\":\n                extracted = extract_mc_answer(response, valid_letters)\n                result = QuestionResult(\n                    question_id=idx,\n                    prompt=prompt,\n                    response=response,\n                    extracted_answer=extracted,\n                    gold_answer=gold,\n                    correct=(extracted == gold) if extracted else False,\n                    **stats,\n                )\n            elif config.kind == \"math\":\n                extracted = extract_boxed_answer(response)\n                correct = (\n                    check_aime_answer(extracted, int(gold)) if extracted else False\n                )\n                result = QuestionResult(\n                    question_id=idx,\n                    prompt=prompt,\n                    response=response,\n                    extracted_answer=extracted,\n                    gold_answer=gold,\n                    correct=correct,\n                    **stats,\n                )\n            elif config.kind == \"code\":\n                # HumanEval needs 
preserved indentation (function body completion)\n                keep_indent = benchmark_name == \"humaneval\"\n                code = extract_code_block(response, preserve_indent=keep_indent)\n                if code is None:\n                    result = QuestionResult(\n                        question_id=idx,\n                        prompt=prompt,\n                        response=response,\n                        extracted_answer=None,\n                        gold_answer=gold,\n                        correct=False,\n                        error=\"No code block extracted\",\n                        **stats,\n                    )\n                elif benchmark_name == \"humaneval\":\n                    passed, diag = run_humaneval_test(\n                        exec_meta[\"problem\"],\n                        code,\n                    )\n                    result = QuestionResult(\n                        question_id=idx,\n                        prompt=prompt,\n                        response=response,\n                        extracted_answer=\"pass\" if passed else \"fail\",\n                        gold_answer=gold,\n                        correct=passed,\n                        error=diag if not passed else None,\n                        **stats,\n                    )\n                elif benchmark_name == \"livecodebench\":\n                    passed, diag = run_livecodebench_test(\n                        code,\n                        exec_meta[\"sample\"],\n                    )\n                    result = QuestionResult(\n                        question_id=idx,\n                        prompt=prompt,\n                        response=response,\n                        extracted_answer=\"pass\" if passed else \"fail\",\n                        gold_answer=gold,\n                        correct=passed,\n                        error=diag if not passed else None,\n                        **stats,\n                    )\n                else:\n                    result = QuestionResult(\n                        question_id=idx,\n                        prompt=prompt,\n                        response=response,\n                        extracted_answer=None,\n                        gold_answer=gold,\n                        correct=False,\n                        error=\"Unknown code benchmark\",\n                        **stats,\n                    )\n            else:\n                result = QuestionResult(\n                    question_id=idx,\n                    prompt=prompt,\n                    response=response,\n                    extracted_answer=None,\n                    gold_answer=gold,\n                    correct=False,\n                    error=\"Unsupported kind\",\n                    **stats,\n                )\n\n        results[idx] = result\n\n        async with lock:\n            completed += 1\n            n = completed\n        if n % max(1, total // 20) == 0 or n == total:\n            correct_so_far = sum(1 for r in results if r is not None and r.correct)\n            answered = sum(1 for r in results if r is not None)\n            logger.info(\n                f\"  [{n}/{total}] {correct_so_far}/{answered} correct \"\n                f\"({correct_so_far / max(answered, 1):.1%})\"\n            )\n\n    async with httpx.AsyncClient() as http_client:\n        tasks = [process_question(i, doc, http_client) for i, doc in enumerate(ds)]\n        await asyncio.gather(*tasks)\n\n    return [r for r in results if r is not None]\n\n\n# 
---------------------------------------------------------------------------\n# Results display\n# ---------------------------------------------------------------------------\n\n\ndef print_results(\n    benchmark_name: str,\n    results: list[QuestionResult],\n    concurrency: int | None = None,\n) -> dict[str, Any]:\n    total = len(results)\n    correct = sum(r.correct for r in results)\n    errors = sum(1 for r in results if r.error)\n    no_extract = sum(1 for r in results if r.extracted_answer is None and not r.error)\n    accuracy = correct / max(total, 1)\n\n    total_prompt_tokens = sum(r.prompt_tokens for r in results)\n    total_completion_tokens = sum(r.completion_tokens for r in results)\n    total_reasoning_tokens = sum(r.reasoning_tokens for r in results)\n    total_elapsed = sum(r.elapsed_s for r in results)\n    wall_clock = max(r.elapsed_s for r in results) if results else 0.0\n    avg_gen_tps = total_completion_tokens / total_elapsed if total_elapsed > 0 else 0.0\n\n    label = f\"[c={concurrency}] \" if concurrency is not None else \"\"\n    print(f\"\\n{label}{benchmark_name}: {correct}/{total} ({accuracy:.1%})\")\n    tok_line = f\"  tokens: {total_prompt_tokens:,} prompt + {total_completion_tokens:,} completion\"\n    if total_reasoning_tokens > 0:\n        tok_line += f\" ({total_reasoning_tokens:,} reasoning)\"\n    tok_line += (\n        f\"  |  avg gen tps: {avg_gen_tps:.1f}\"\n        f\"  |  total time: {total_elapsed:.1f}s  wall clock: {wall_clock:.1f}s\"\n    )\n    print(tok_line)\n    if errors:\n        print(f\"  API errors: {errors}\")\n    if no_extract:\n        print(f\"  No answer extracted: {no_extract}\")\n\n    return {\n        \"benchmark\": benchmark_name,\n        \"accuracy\": accuracy,\n        \"correct\": correct,\n        \"total\": total,\n        \"errors\": errors,\n        \"no_extract\": no_extract,\n        \"total_prompt_tokens\": total_prompt_tokens,\n        \"total_completion_tokens\": total_completion_tokens,\n        \"total_reasoning_tokens\": total_reasoning_tokens,\n        \"total_elapsed_s\": total_elapsed,\n        \"wall_clock_s\": wall_clock,\n        \"avg_gen_tps\": avg_gen_tps,\n    }\n\n\ndef print_comparison(\n    benchmark_name: str,\n    results_by_c: dict[int, list[QuestionResult]],\n) -> None:\n    levels = sorted(results_by_c.keys())\n    print(f\"\\n{'=' * 70}\")\n    print(f\"COMPARISON: {benchmark_name}\")\n    print(f\"{'=' * 70}\")\n\n    header = f\"{'Concurrency':<15} {'Accuracy':>10} {'Correct':>10} {'Total':>10} {'Comp Tokens':>12} {'Wall Clock':>12} {'Avg Gen TPS':>12}\"\n    print(header)\n    print(\"-\" * len(header))\n    for c in levels:\n        r = results_by_c[c]\n        correct = sum(q.correct for q in r)\n        total = len(r)\n        comp_tok = sum(q.completion_tokens for q in r)\n        total_elapsed = sum(q.elapsed_s for q in r)\n        avg_tps = comp_tok / total_elapsed if total_elapsed > 0 else 0.0\n        wall = max(q.elapsed_s for q in r) if r else 0.0\n        print(\n            f\"c={c:<13} {correct / max(total, 1):>10.1%} {correct:>10} {total:>10}\"\n            f\" {comp_tok:>12,} {wall:>11.1f}s {avg_tps:>12.1f}\"\n        )\n\n    if len(levels) >= 2:\n        base_results = results_by_c[levels[0]]\n        test_results = results_by_c[levels[-1]]\n        changed = sum(\n            1\n            for br, tr in zip(base_results, test_results, strict=True)\n            if br.correct != tr.correct\n        )\n        total = min(len(base_results), len(test_results))\n       
 print(\n            f\"\\nQuestions with different correctness (c={levels[0]} vs c={levels[-1]}): {changed}/{total}\"\n        )\n        if changed == 0:\n            print(\"Batching produced identical quality.\")\n        elif changed <= total * 0.01:\n            print(\"Negligible quality difference from batching.\")\n        else:\n            print(\n                f\"WARNING: {changed / max(total, 1) * 100:.1f}% of questions changed.\"\n            )\n    print()\n\n\n# ---------------------------------------------------------------------------\n# Interactive task picker\n# ---------------------------------------------------------------------------\n\n\ndef pick_tasks_interactive() -> list[str]:\n    import termios\n    import tty\n\n    if not sys.stdin.isatty():\n        logger.error(\"No --tasks specified and stdin is not a terminal.\")\n        return []\n\n    items = [(name, cfg.description) for name, cfg in BENCHMARKS.items()]\n    selected: list[bool] = [False] * len(items)\n    cursor = 0\n    total_lines = len(items) + 4\n\n    def write(s: str) -> None:\n        sys.stdout.write(s)\n\n    def render(first: bool = False) -> None:\n        if not first:\n            write(f\"\\033[{total_lines}A\")\n        write(\"\\033[J\")\n        write(\n            \"\\033[1mSelect benchmarks\\033[0m (up/down, space toggle, enter confirm, q quit)\\r\\n\\r\\n\"\n        )\n        for i, (name, desc) in enumerate(items):\n            marker = \">\" if i == cursor else \" \"\n            check = \"x\" if selected[i] else \" \"\n            line = f\"  {marker} [{check}] {name:<17} {desc}\"\n            write(f\"\\033[7m{line}\\033[0m\\r\\n\" if i == cursor else f\"{line}\\r\\n\")\n        write(f\"\\r\\n  {sum(selected)} selected\\r\\n\")\n        sys.stdout.flush()\n\n    fd = sys.stdin.fileno()\n    old = termios.tcgetattr(fd)\n    try:\n        tty.setraw(fd)\n        write(\"\\033[?25l\")\n        render(first=True)\n        while True:\n            ch = sys.stdin.read(1)\n            if ch in (\"q\", \"\\x03\"):\n                write(\"\\033[?25h\\033[0m\\r\\n\")\n                return []\n            elif ch in (\"\\r\", \"\\n\"):\n                break\n            elif ch == \" \":\n                selected[cursor] = not selected[cursor]\n            elif ch == \"\\x1b\":\n                seq = sys.stdin.read(2)\n                if seq == \"[A\":\n                    cursor = (cursor - 1) % len(items)\n                elif seq == \"[B\":\n                    cursor = (cursor + 1) % len(items)\n            render()\n    finally:\n        termios.tcsetattr(fd, termios.TCSADRAIN, old)\n        write(\"\\033[?25h\\033[0m\\r\\n\")\n        sys.stdout.flush()\n\n    chosen = [name for (name, _), sel in zip(items, selected, strict=True) if sel]\n    if chosen:\n        logger.info(f\"Selected: {', '.join(chosen)}\")\n    return chosen\n\n\n# ---------------------------------------------------------------------------\n# Results persistence\n# ---------------------------------------------------------------------------\n\n\ndef save_results(\n    results_dir: str,\n    benchmark_name: str,\n    model: str,\n    concurrency: int,\n    results: list[QuestionResult],\n    scores: dict[str, Any],\n) -> Path:\n    out_dir = Path(results_dir) / model.replace(\"/\", \"_\") / benchmark_name\n    out_dir.mkdir(parents=True, exist_ok=True)\n    ts = time.strftime(\"%Y%m%d_%H%M%S\")\n    path = out_dir / f\"c{concurrency}_{ts}.json\"\n\n    data = {\n        \"benchmark\": benchmark_name,\n     
   \"model\": model,\n        \"concurrency\": concurrency,\n        \"scores\": scores,\n        \"results\": [\n            {\n                \"question_id\": r.question_id,\n                \"prompt\": r.prompt,\n                \"response\": r.response,\n                \"extracted_answer\": r.extracted_answer,\n                \"gold_answer\": r.gold_answer,\n                \"correct\": r.correct,\n                \"error\": r.error,\n                \"prompt_tokens\": r.prompt_tokens,\n                \"completion_tokens\": r.completion_tokens,\n                \"reasoning_tokens\": r.reasoning_tokens,\n                \"elapsed_s\": round(r.elapsed_s, 2),\n            }\n            for r in results\n        ],\n    }\n    with open(path, \"w\") as f:\n        json.dump(data, f, indent=2)\n    logger.info(f\"Results saved to {path}\")\n    return path\n\n\n# ---------------------------------------------------------------------------\n# CLI\n# ---------------------------------------------------------------------------\n\n\ndef parse_int_list(values: list[str]) -> list[int]:\n    items: list[int] = []\n    for v in values:\n        for part in v.split(\",\"):\n            if part.strip():\n                items.append(int(part.strip()))\n    return items\n\n\ndef main() -> int:\n    ap = argparse.ArgumentParser(\n        prog=\"exo-eval\",\n        description=\"Quality evaluation for exo — matches Artificial Analysis methodology.\",\n    )\n    add_common_instance_args(ap)\n\n    ap.add_argument(\n        \"--tasks\",\n        default=None,\n        help=\"Comma-separated benchmark names. Omit for interactive picker.\",\n    )\n    ap.add_argument(\n        \"--limit\",\n        type=int,\n        default=None,\n        help=\"Max questions per benchmark (for fast iteration).\",\n    )\n\n    reasoning_group = ap.add_mutually_exclusive_group()\n    reasoning_group.add_argument(\n        \"--reasoning\",\n        action=\"store_true\",\n        default=None,\n        help=\"Force reasoning-model settings (temperature=0.6, max_tokens=65536).\",\n    )\n    reasoning_group.add_argument(\n        \"--no-reasoning\",\n        action=\"store_true\",\n        default=False,\n        help=\"Force non-reasoning settings (temperature=0, max_tokens=16384).\",\n    )\n\n    ap.add_argument(\n        \"--temperature\", type=float, default=None, help=\"Override temperature.\"\n    )\n    ap.add_argument(\"--top-p\", type=float, default=None, help=\"Override top_p.\")\n    ap.add_argument(\n        \"--max-tokens\", type=int, default=None, help=\"Override max output tokens.\"\n    )\n    ap.add_argument(\n        \"--num-concurrent\",\n        type=int,\n        default=1,\n        help=\"Concurrent API requests (default: 1).\",\n    )\n    ap.add_argument(\n        \"--compare-concurrency\",\n        nargs=\"+\",\n        default=None,\n        help=\"Run at multiple concurrency levels and compare. E.g. 
--compare-concurrency 1,4\",\n    )\n    ap.add_argument(\n        \"--request-timeout\",\n        type=float,\n        default=None,\n        help=\"Per-request timeout in seconds (default: no timeout).\",\n    )\n    ap.add_argument(\n        \"--reasoning-effort\",\n        default=None,\n        choices=[\"low\", \"medium\", \"high\"],\n        help=\"Override reasoning effort (default: 'high' for reasoning models, none for non-reasoning).\",\n    )\n    ap.add_argument(\n        \"--difficulty\",\n        default=None,\n        choices=[\"easy\", \"medium\", \"hard\"],\n        help=\"Filter by difficulty (livecodebench only). E.g. --difficulty hard\",\n    )\n    ap.add_argument(\n        \"--results-dir\",\n        default=\"eval_results\",\n        help=\"Directory for result JSON files (default: eval_results).\",\n    )\n    ap.add_argument(\n        \"--skip-instance-setup\",\n        action=\"store_true\",\n        help=\"Skip exo instance management (assumes model is already running).\",\n    )\n\n    args, _ = ap.parse_known_args()\n\n    # Resolve tasks\n    if args.tasks:\n        task_names = [t.strip() for t in args.tasks.split(\",\") if t.strip()]\n    else:\n        task_names = pick_tasks_interactive()\n    if not task_names:\n        return 0\n\n    for t in task_names:\n        if t not in BENCHMARKS:\n            logger.error(f\"Unknown benchmark '{t}'. Available: {', '.join(BENCHMARKS)}\")\n            return 1\n\n    # Instance management\n    client = ExoClient(args.host, args.port, timeout_s=args.timeout)\n    instance_id: str | None = None\n\n    if not args.skip_instance_setup:\n        short_id, full_model_id = resolve_model_short_id(\n            client,\n            args.model,\n            force_download=args.force_download,\n        )\n        selected = settle_and_fetch_placements(\n            client,\n            full_model_id,\n            args,\n            settle_timeout=args.settle_timeout,\n        )\n        if not selected:\n            logger.error(\"No valid placements matched your filters.\")\n            return 1\n\n        selected.sort(\n            key=lambda p: (\n                str(p.get(\"instance_meta\", \"\")),\n                str(p.get(\"sharding\", \"\")),\n                -nodes_used_in_instance(p[\"instance\"]),\n            ),\n            reverse=True,\n        )\n        preview = selected[0]\n        instance = preview[\"instance\"]\n        instance_id = instance_id_from_instance(instance)\n\n        logger.info(\n            f\"PLACEMENT: {preview['sharding']} / {preview['instance_meta']} / \"\n            f\"nodes={nodes_used_in_instance(instance)}\"\n        )\n\n        settle_deadline = (\n            time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None\n        )\n        download_duration = run_planning_phase(\n            client,\n            full_model_id,\n            preview,\n            args.danger_delete_downloads,\n            args.timeout,\n            settle_deadline,\n        )\n        if download_duration is not None:\n            logger.info(f\"Download: {download_duration:.1f}s\")\n\n        client.request_json(\"POST\", \"/instance\", body={\"instance\": instance})\n        try:\n            wait_for_instance_ready(client, instance_id)\n        except (RuntimeError, TimeoutError) as e:\n            logger.error(f\"Failed to initialize: {e}\")\n            with contextlib.suppress(ExoHttpError):\n                client.request_json(\"DELETE\", f\"/instance/{instance_id}\")\n      
      return 1\n        time.sleep(1)\n    else:\n        full_model_id = args.model\n\n    # Auto-detect reasoning from model config\n    model_config = load_model_config(full_model_id)\n    if args.reasoning:\n        is_reasoning = True\n    elif args.no_reasoning:\n        is_reasoning = False\n    elif model_config is not None:\n        is_reasoning = model_config.get(\"reasoning\", False)\n        logger.info(\n            f\"Auto-detected from config: {model_config['name']} → \"\n            f\"{'reasoning' if is_reasoning else 'non-reasoning'}\"\n        )\n    else:\n        is_reasoning = False\n        logger.warning(\n            f\"Model '{full_model_id}' not found in eval_configs/models.toml. \"\n            f\"Defaulting to non-reasoning. Use --reasoning to override.\"\n        )\n\n    # Resolve temperature, max_tokens, reasoning_effort\n    # Priority: CLI flag > per-model config > global defaults\n    cfg = model_config or {}\n\n    if args.temperature is not None:\n        temperature = args.temperature\n    elif \"temperature\" in cfg:\n        temperature = float(cfg[\"temperature\"])\n    else:\n        temperature = (\n            TEMPERATURE_REASONING if is_reasoning else TEMPERATURE_NON_REASONING\n        )\n\n    if args.max_tokens is not None:\n        max_tokens = args.max_tokens\n    elif \"max_tokens\" in cfg:\n        max_tokens = int(cfg[\"max_tokens\"])\n    else:\n        max_tokens = REASONING_MAX_TOKENS if is_reasoning else DEFAULT_MAX_TOKENS\n\n    if args.top_p is not None:\n        top_p: float | None = args.top_p\n    elif \"top_p\" in cfg:\n        top_p = float(cfg[\"top_p\"])\n    else:\n        top_p = None  # let server use its default\n\n    if args.reasoning_effort is not None:\n        reasoning_effort = args.reasoning_effort\n    elif \"reasoning_effort\" in cfg:\n        reasoning_effort = str(cfg[\"reasoning_effort\"])\n    else:\n        reasoning_effort = \"high\" if is_reasoning else None\n    base_url = f\"http://{args.host}:{args.port}\"\n\n    logger.info(f\"Model: {full_model_id}\")\n    logger.info(\n        f\"Settings: temperature={temperature}, max_tokens={max_tokens}, \"\n        + (f\"top_p={top_p}, \" if top_p is not None else \"\")\n        + f\"reasoning={'yes' if is_reasoning else 'no'}\"\n        + (f\", reasoning_effort={reasoning_effort}\" if reasoning_effort else \"\")\n    )\n\n    try:\n        if args.compare_concurrency:\n            concurrency_levels = parse_int_list(args.compare_concurrency)\n            for task_name in task_names:\n                results_by_c: dict[int, list[QuestionResult]] = {}\n                for c in concurrency_levels:\n                    logger.info(f\"\\n{'=' * 50}\")\n                    logger.info(f\"Running {task_name} at concurrency={c}\")\n                    results = asyncio.run(\n                        evaluate_benchmark(\n                            task_name,\n                            base_url,\n                            full_model_id,\n                            temperature,\n                            max_tokens,\n                            concurrency=c,\n                            limit=args.limit,\n                            timeout=args.request_timeout,\n                            reasoning_effort=reasoning_effort,\n                            top_p=top_p,\n                            difficulty=args.difficulty,\n                        )\n                    )\n                    if results:\n                        scores = print_results(task_name, 
results, concurrency=c)\n                        save_results(\n                            args.results_dir,\n                            task_name,\n                            full_model_id,\n                            c,\n                            results,\n                            scores,\n                        )\n                        results_by_c[c] = results\n                if len(results_by_c) >= 2:\n                    print_comparison(task_name, results_by_c)\n        else:\n            for task_name in task_names:\n                results = asyncio.run(\n                    evaluate_benchmark(\n                        task_name,\n                        base_url,\n                        full_model_id,\n                        temperature,\n                        max_tokens,\n                        concurrency=args.num_concurrent,\n                        limit=args.limit,\n                        timeout=args.request_timeout,\n                        reasoning_effort=reasoning_effort,\n                        top_p=top_p,\n                        difficulty=args.difficulty,\n                    )\n                )\n                if results:\n                    scores = print_results(task_name, results)\n                    save_results(\n                        args.results_dir,\n                        task_name,\n                        full_model_id,\n                        args.num_concurrent,\n                        results,\n                        scores,\n                    )\n    finally:\n        if instance_id is not None:\n            try:\n                client.request_json(\"DELETE\", f\"/instance/{instance_id}\")\n            except ExoHttpError as e:\n                if e.status != 404:\n                    raise\n            wait_for_instance_gone(client, instance_id)\n\n    return 0\n\n\nif __name__ == \"__main__\":\n    raise SystemExit(main())\n"
  },
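  {
    "path": "bench/examples/resolve_setting_sketch.py",
    "content": "# Illustrative sketch, NOT part of the original repo: a generic version of the\n# \"CLI flag > per-model config > global defaults\" resolution pattern that the\n# eval entry point repeats for temperature, max_tokens, top_p, and\n# reasoning_effort. All names in this file are hypothetical.\nfrom typing import Any\n\n\ndef resolve_setting(\n    cli_value: Any, model_config: dict[str, Any], key: str, default: Any\n) -> Any:\n    # An explicit CLI flag always wins.\n    if cli_value is not None:\n        return cli_value\n    # Otherwise fall back to the per-model config entry, if present.\n    if key in model_config:\n        return model_config[key]\n    # Finally, use the global default.\n    return default\n\n\n# Mirrors, e.g., the temperature resolution in the eval entry point:\n# temperature = resolve_setting(args.temperature, model_config or {}, \"temperature\", TEMPERATURE_NON_REASONING)\n"
  },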
  {
    "path": "bench/harness.py",
    "content": "# type: ignore\nfrom __future__ import annotations\n\nimport argparse\nimport http.client\nimport json\nimport os\nimport time\nfrom typing import Any\nfrom urllib.parse import urlencode\n\nfrom loguru import logger\n\n_SETTLE_INITIAL_BACKOFF_S = 1.0\n_SETTLE_MAX_BACKOFF_S = 60.0\n_SETTLE_BACKOFF_MULTIPLIER = 2.0\n\n\nclass ExoHttpError(RuntimeError):\n    def __init__(self, status: int, reason: str, body_preview: str):\n        super().__init__(f\"HTTP {status} {reason}: {body_preview}\")\n        self.status = status\n\n\nclass ExoClient:\n    def __init__(self, host: str, port: int, timeout_s: float = 7200.0):\n        self.host = host\n        self.port = port\n        self.timeout_s = timeout_s\n\n    def request_json(\n        self,\n        method: str,\n        path: str,\n        params: dict[str, Any] | None = None,\n        body: dict[str, Any] | None = None,\n        headers: dict[str, str] | None = None,\n    ) -> Any:\n        if not path.startswith(\"/\"):\n            path = \"/\" + path\n        if params:\n            path = path + \"?\" + urlencode(params)\n\n        conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s)\n        try:\n            payload: bytes | None = None\n            hdrs: dict[str, str] = {\"Accept\": \"application/json\"}\n\n            if body is not None:\n                payload = json.dumps(body).encode(\"utf-8\")\n                hdrs[\"Content-Type\"] = \"application/json\"\n            if headers:\n                hdrs.update(headers)\n\n            conn.request(method.upper(), path, body=payload, headers=hdrs)\n            resp = conn.getresponse()\n            raw = resp.read()\n            text = raw.decode(\"utf-8\", errors=\"replace\") if raw else \"\"\n\n            if resp.status >= 400:\n                raise ExoHttpError(resp.status, resp.reason, text[:300])\n\n            if not text:\n                return None\n            return json.loads(text)\n        finally:\n            conn.close()\n\n    def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]:\n        return self.request_json(\"POST\", \"/bench/chat/completions\", body=payload)\n\n\ndef unwrap_instance(instance: dict[str, Any]) -> dict[str, Any]:\n    if len(instance) != 1:\n        raise KeyError(f\"Expected 1 key, got keys={list(instance.keys())}\")\n\n    tag = next(iter(instance))\n    inner = instance[tag]\n    if not isinstance(inner, dict):\n        raise TypeError(f\"payload for {tag} must be dict, got {type(inner)}\")\n    return inner\n\n\ndef instance_id_from_instance(instance: dict[str, Any]) -> str:\n    inner = unwrap_instance(instance)\n    return str(inner[\"instanceId\"])\n\n\ndef nodes_used_in_instance(instance: dict[str, Any]) -> int:\n    inner = unwrap_instance(instance)\n    return len(inner[\"shardAssignments\"][\"nodeToRunner\"])\n\n\ndef runner_ids_from_instance(instance: dict[str, Any]) -> list[str]:\n    inner = unwrap_instance(instance)\n    runner_to_shard = inner[\"shardAssignments\"][\"runnerToShard\"]\n    return list(runner_to_shard.keys())\n\n\ndef runner_ready(runner: dict[str, Any]) -> bool:\n    return \"RunnerReady\" in runner\n\n\ndef runner_failed(runner: dict[str, Any]) -> bool:\n    return \"RunnerFailed\" in runner\n\n\ndef get_runner_failed_message(runner: dict[str, Any]) -> str | None:\n    if \"RunnerFailed\" in runner:\n        return runner[\"RunnerFailed\"].get(\"errorMessage\")\n    return None\n\n\ndef wait_for_instance_ready(\n    client: ExoClient, 
instance_id: str, timeout: float = 24000.0\n) -> None:\n    start_time = time.time()\n    instance_existed = False\n    while time.time() - start_time < timeout:\n        state = client.request_json(\"GET\", \"/state\")\n        instances = state.get(\"instances\", {})\n\n        if instance_id not in instances:\n            if instance_existed:\n                # Instance was deleted after being created - likely due to runner failure\n                raise RuntimeError(\n                    f\"Instance {instance_id} was deleted (runner may have failed)\"\n                )\n            time.sleep(0.1)\n            continue\n\n        instance_existed = True\n        instance = instances[instance_id]\n        runner_ids = runner_ids_from_instance(instance)\n        runners = state.get(\"runners\", {})\n\n        # Check for failed runners first\n        for rid in runner_ids:\n            runner = runners.get(rid, {})\n            if runner_failed(runner):\n                error_msg = get_runner_failed_message(runner) or \"Unknown error\"\n                raise RuntimeError(f\"Runner {rid} failed: {error_msg}\")\n\n        if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):\n            return\n\n        time.sleep(0.1)\n\n    raise TimeoutError(f\"Instance {instance_id} did not become ready within {timeout=}\")\n\n\ndef wait_for_instance_gone(\n    client: ExoClient, instance_id: str, timeout: float = 3.0\n) -> None:\n    start_time = time.time()\n    while time.time() - start_time < timeout:\n        try:\n            client.request_json(\"GET\", f\"/instance/{instance_id}\")\n            time.sleep(0.4)\n        except ExoHttpError as e:\n            if e.status == 404:\n                return\n            raise\n\n    raise TimeoutError(f\"Instance {instance_id} did not get deleted within {timeout=}\")\n\n\ndef resolve_model_short_id(\n    client: ExoClient, model_arg: str, *, force_download: bool = False\n) -> tuple[str, str]:\n    models = client.request_json(\"GET\", \"/models\") or {}\n    data = models.get(\"data\") or []\n\n    for m in data:\n        if (m.get(\"name\") or \"\").lower() == model_arg.lower():\n            short_id = str(m[\"name\"])\n            full_id = str(m.get(\"hugging_face_id\") or m[\"name\"])\n            return short_id, full_id\n\n    for m in data:\n        if m.get(\"hugging_face_id\") == model_arg:\n            short_id = str(m[\"name\"])\n            full_id = str(m[\"hugging_face_id\"])\n            return short_id, full_id\n\n    if force_download and \"/\" in model_arg:\n        logger.info(f\"Model not in /models, adding from HuggingFace: {model_arg}\")\n        result = client.request_json(\n            \"POST\", \"/models/add\", body={\"model_id\": model_arg}\n        )\n        if result:\n            short_id = str(result.get(\"name\") or model_arg.rsplit(\"/\", 1)[-1])\n            full_id = str(result.get(\"hugging_face_id\") or model_arg)\n            return short_id, full_id\n\n    raise ValueError(f\"Model not found in /models: {model_arg}\")\n\n\ndef placement_filter(instance_meta: str, wanted: str) -> bool:\n    s = (instance_meta or \"\").lower()\n    if wanted == \"both\":\n        return (\"ring\" in s) or (\"jaccl\" in s)\n    return wanted in s\n\n\ndef sharding_filter(sharding: str, wanted: str) -> bool:\n    s = (sharding or \"\").lower()\n    if wanted == \"both\":\n        return (\"pipeline\" in s) or (\"tensor\" in s)\n    return wanted in s\n\n\ndef fetch_and_filter_placements(\n    client: ExoClient, 
full_model_id: str, args: argparse.Namespace\n) -> list[dict[str, Any]]:\n    previews_resp = client.request_json(\n        \"GET\", \"/instance/previews\", params={\"model_id\": full_model_id}\n    )\n    previews = previews_resp.get(\"previews\") or []\n\n    selected: list[dict[str, Any]] = []\n    for p in previews:\n        if p.get(\"error\") is not None:\n            continue\n        if not placement_filter(str(p.get(\"instance_meta\", \"\")), args.instance_meta):\n            continue\n        if not sharding_filter(str(p.get(\"sharding\", \"\")), args.sharding):\n            continue\n\n        instance = p.get(\"instance\")\n        if not isinstance(instance, dict):\n            continue\n\n        n = nodes_used_in_instance(instance)\n        # Skip single-node tensor/jaccl placements when \"both\" is selected: the\n        # equivalent single-node pipeline ring placement is already a candidate.\n        if n == 1 and (\n            (args.sharding == \"both\" and \"tensor\" in p.get(\"sharding\", \"\").lower())\n            or (\n                args.instance_meta == \"both\"\n                and \"jaccl\" in p.get(\"instance_meta\", \"\").lower()\n            )\n        ):\n            continue\n\n        if (\n            args.skip_pipeline_jaccl\n            and (\n                args.instance_meta == \"both\"\n                and \"jaccl\" in p.get(\"instance_meta\", \"\").lower()\n            )\n            and (\n                args.sharding == \"both\" and \"pipeline\" in p.get(\"sharding\", \"\").lower()\n            )\n        ):\n            continue\n\n        if (\n            args.skip_tensor_ring\n            and (\n                args.instance_meta == \"both\"\n                and \"ring\" in p.get(\"instance_meta\", \"\").lower()\n            )\n            and (args.sharding == \"both\" and \"tensor\" in p.get(\"sharding\", \"\").lower())\n        ):\n            continue\n\n        if args.min_nodes <= n <= args.max_nodes:\n            selected.append(p)\n\n    return selected\n\n\ndef settle_and_fetch_placements(\n    client: ExoClient,\n    full_model_id: str,\n    args: argparse.Namespace,\n    settle_timeout: float = 0,\n) -> list[dict[str, Any]]:\n    selected = fetch_and_filter_placements(client, full_model_id, args)\n\n    if not selected and settle_timeout > 0:\n        backoff = _SETTLE_INITIAL_BACKOFF_S\n        deadline = time.monotonic() + settle_timeout\n        while not selected and time.monotonic() < deadline:\n            remaining = deadline - time.monotonic()\n            logger.warning(\n                f\"No valid placements yet (cluster may still be settling). 
\"\n                f\"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)...\"\n            )\n            time.sleep(min(backoff, remaining))\n            backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)\n            selected = fetch_and_filter_placements(client, full_model_id, args)\n\n    return selected\n\n\ndef run_planning_phase(\n    client: ExoClient,\n    full_model_id: str,\n    preview: dict[str, Any],\n    danger_delete: bool,\n    timeout: float,\n    settle_deadline: float | None,\n) -> float | None:\n    \"\"\"Check disk space and ensure model is downloaded before benchmarking.\n\n    Returns the wall-clock download duration in seconds if a fresh download\n    was needed, or None if the model was already cached on all nodes.\n    \"\"\"\n    # Get model size from /models\n    models = client.request_json(\"GET\", \"/models\") or {}\n    model_bytes = 0\n    for m in models.get(\"data\", []):\n        if m.get(\"hugging_face_id\") == full_model_id:\n            model_bytes = m.get(\"storage_size_megabytes\", 0) * 1024 * 1024\n            break\n\n    if not model_bytes:\n        logger.warning(\n            f\"Could not determine size for {full_model_id}, skipping disk check\"\n        )\n        return None\n\n    # Get nodes from preview\n    inner = unwrap_instance(preview[\"instance\"])\n    node_ids = list(inner[\"shardAssignments\"][\"nodeToRunner\"].keys())\n    runner_to_shard = inner[\"shardAssignments\"][\"runnerToShard\"]\n\n    state = client.request_json(\"GET\", \"/state\")\n    downloads = state.get(\"downloads\", {})\n    node_disk = state.get(\"nodeDisk\", {})\n\n    needs_download = False\n\n    for node_id in node_ids:\n        node_downloads = downloads.get(node_id, [])\n\n        # Check if model already downloaded on this node\n        already_downloaded = any(\n            \"DownloadCompleted\" in p\n            and unwrap_instance(p[\"DownloadCompleted\"][\"shardMetadata\"])[\"modelCard\"][\n                \"modelId\"\n            ]\n            == full_model_id\n            for p in node_downloads\n        )\n        if already_downloaded:\n            continue\n\n        needs_download = True\n\n        # Wait for disk info if settle_deadline is set\n        disk_info = node_disk.get(node_id, {})\n        backoff = _SETTLE_INITIAL_BACKOFF_S\n        while not disk_info and settle_deadline and time.monotonic() < settle_deadline:\n            remaining = settle_deadline - time.monotonic()\n            logger.info(\n                f\"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)...\"\n            )\n            time.sleep(min(backoff, remaining))\n            backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)\n            state = client.request_json(\"GET\", \"/state\")\n            node_disk = state.get(\"nodeDisk\", {})\n            disk_info = node_disk.get(node_id, {})\n\n        if not disk_info:\n            logger.warning(f\"No disk info for {node_id}, skipping space check\")\n            continue\n\n        avail = disk_info.get(\"available\", {}).get(\"inBytes\", 0)\n        if avail >= model_bytes:\n            continue\n\n        if not danger_delete:\n            raise RuntimeError(\n                f\"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, \"\n                f\"have {avail // (1024**3)}GB. 
Use --danger-delete-downloads to free space.\"\n            )\n\n        # Delete from smallest to largest (skip read-only models from EXO_MODELS_PATH)\n        completed = [\n            (\n                unwrap_instance(p[\"DownloadCompleted\"][\"shardMetadata\"])[\"modelCard\"][\n                    \"modelId\"\n                ],\n                p[\"DownloadCompleted\"][\"total\"][\"inBytes\"],\n            )\n            for p in node_downloads\n            if \"DownloadCompleted\" in p\n            and not p[\"DownloadCompleted\"].get(\"readOnly\", False)\n        ]\n        for del_model, size in sorted(completed, key=lambda x: x[1]):\n            logger.info(f\"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)\")\n            client.request_json(\"DELETE\", f\"/download/{node_id}/{del_model}\")\n            avail += size\n            if avail >= model_bytes:\n                break\n\n        if avail < model_bytes:\n            raise RuntimeError(f\"Could not free enough space on {node_id}\")\n\n    # Start downloads (idempotent)\n    download_t0 = time.perf_counter() if needs_download else None\n    for node_id in node_ids:\n        runner_id = inner[\"shardAssignments\"][\"nodeToRunner\"][node_id]\n        shard = runner_to_shard[runner_id]\n        client.request_json(\n            \"POST\",\n            \"/download/start\",\n            body={\n                \"targetNodeId\": node_id,\n                \"shardMetadata\": shard,\n            },\n        )\n        logger.info(f\"Started download on {node_id}\")\n\n    # Wait for downloads\n    start = time.time()\n    while time.time() - start < timeout:\n        state = client.request_json(\"GET\", \"/state\")\n        downloads = state.get(\"downloads\", {})\n        all_done = True\n        for node_id in node_ids:\n            done = any(\n                \"DownloadCompleted\" in p\n                and unwrap_instance(p[\"DownloadCompleted\"][\"shardMetadata\"])[\n                    \"modelCard\"\n                ][\"modelId\"]\n                == full_model_id\n                for p in downloads.get(node_id, [])\n            )\n            failed = [\n                p[\"DownloadFailed\"][\"errorMessage\"]\n                for p in downloads.get(node_id, [])\n                if \"DownloadFailed\" in p\n                and unwrap_instance(p[\"DownloadFailed\"][\"shardMetadata\"])[\"modelCard\"][\n                    \"modelId\"\n                ]\n                == full_model_id\n            ]\n            if failed:\n                raise RuntimeError(f\"Download failed on {node_id}: {failed[0]}\")\n            if not done:\n                all_done = False\n        if all_done:\n            if download_t0 is not None:\n                return time.perf_counter() - download_t0\n            return None\n        time.sleep(1)\n\n    raise TimeoutError(\"Downloads did not complete in time\")\n\n\ndef add_common_instance_args(ap: argparse.ArgumentParser) -> None:\n    ap.add_argument(\"--host\", default=os.environ.get(\"EXO_HOST\", \"localhost\"))\n    ap.add_argument(\n        \"--port\", type=int, default=int(os.environ.get(\"EXO_PORT\", \"52415\"))\n    )\n    ap.add_argument(\"--model\", required=True, help=\"Model short id or huggingface id\")\n    ap.add_argument(\n        \"--force-download\",\n        action=\"store_true\",\n        help=\"If model not in /models, add it from HuggingFace via exo and download.\",\n    )\n    ap.add_argument(\n        \"--max-nodes\",\n        type=int,\n        default=4,\n 
       help=\"Only consider placements using <= this many nodes.\",\n    )\n    ap.add_argument(\n        \"--min-nodes\",\n        type=int,\n        default=1,\n        help=\"Only consider placements using >= this many nodes.\",\n    )\n    ap.add_argument(\n        \"--instance-meta\", choices=[\"ring\", \"jaccl\", \"both\"], default=\"both\"\n    )\n    ap.add_argument(\n        \"--sharding\", choices=[\"pipeline\", \"tensor\", \"both\"], default=\"both\"\n    )\n    ap.add_argument(\n        \"--skip-pipeline-jaccl\",\n        action=\"store_true\",\n        help=\"Skip pipeline+jaccl placements, as it's often pointless.\",\n    )\n    ap.add_argument(\n        \"--skip-tensor-ring\",\n        action=\"store_true\",\n        help=\"Skip tensor+ring placements, as it's so slow.\",\n    )\n    ap.add_argument(\n        \"--timeout\", type=float, default=7200.0, help=\"HTTP timeout (seconds).\"\n    )\n    ap.add_argument(\n        \"--settle-timeout\",\n        type=float,\n        default=0,\n        help=\"Max seconds to wait for the cluster to produce valid placements (0 = try once).\",\n    )\n    ap.add_argument(\n        \"--danger-delete-downloads\",\n        action=\"store_true\",\n        help=\"Delete existing models from smallest to largest to make room for benchmark model.\",\n    )\n"
  },
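  {
    "path": "bench/examples/harness_usage_sketch.py",
    "content": "# Illustrative sketch, NOT part of the original repo: a minimal driver showing\n# how the helpers in bench/harness.py compose. The host, port, and model short\n# id are placeholder values, and the import assumes bench/ is on sys.path; see\n# the eval entry point for the full flow, including instance creation and\n# teardown.\nimport argparse\n\nfrom harness import (\n    ExoClient,\n    nodes_used_in_instance,\n    resolve_model_short_id,\n    settle_and_fetch_placements,\n)\n\nclient = ExoClient(\"localhost\", 52415)\nshort_id, full_model_id = resolve_model_short_id(client, \"llama-3.2-1b\")\n\n# fetch_and_filter_placements reads its filters from an argparse.Namespace, so\n# build one with the same fields that add_common_instance_args defines.\nargs = argparse.Namespace(\n    instance_meta=\"both\",\n    sharding=\"both\",\n    min_nodes=1,\n    max_nodes=4,\n    skip_pipeline_jaccl=False,\n    skip_tensor_ring=False,\n)\n\nfor preview in settle_and_fetch_placements(\n    client, full_model_id, args, settle_timeout=30\n):\n    print(\n        preview[\"sharding\"],\n        preview[\"instance_meta\"],\n        nodes_used_in_instance(preview[\"instance\"]),\n    )\n"
  },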
  {
    "path": "bench/parallel_requests.py",
    "content": "# type: ignore\nimport argparse\nimport asyncio\nimport sys\nimport termios\nimport time\nimport tty\n\nimport aiohttp\n\nNUM_REQUESTS = 10\nBASE_URL = \"\"\n\nQUESTIONS = [\n    \"What is the capital of Australia?\",\n    \"How many bones are in the human body?\",\n    \"What year did World War II end?\",\n    \"What is the speed of light in meters per second?\",\n    \"Who wrote Romeo and Juliet?\",\n    \"What is the chemical formula for water?\",\n    \"How many planets are in our solar system?\",\n    \"What is the largest ocean on Earth?\",\n    \"Who painted the Mona Lisa?\",\n    \"What is the boiling point of water in Celsius?\",\n]\n\n\ndef write(s: str) -> None:\n    sys.stdout.write(s)\n\n\n# ---------------------------------------------------------------------------\n# Model picker (same style as exo_eval)\n# ---------------------------------------------------------------------------\n\n\ndef fetch_models() -> list[str]:\n    import json\n    import urllib.request\n\n    with urllib.request.urlopen(f\"{BASE_URL}/state\") as resp:\n        data = json.loads(resp.read())\n    model_ids: set[str] = set()\n    for instance in data.get(\"instances\", {}).values():\n        for variant in instance.values():\n            sa = variant.get(\"shardAssignments\", {})\n            model_id = sa.get(\"modelId\")\n            if model_id:\n                model_ids.add(model_id)\n    return sorted(model_ids)\n\n\ndef pick_model() -> str | None:\n    models = fetch_models()\n    if not models:\n        print(\"No models found.\")\n        return None\n\n    cursor = 0\n    total_lines = len(models) + 4\n\n    def render(first: bool = False) -> None:\n        if not first:\n            write(f\"\\033[{total_lines}A\")\n        write(\"\\033[J\")\n        write(\"\\033[1mSelect model\\033[0m (up/down, enter confirm, q quit)\\r\\n\\r\\n\")\n        for i, model in enumerate(models):\n            line = f\"  {'>' if i == cursor else ' '} {model}\"\n            write(f\"\\033[7m{line}\\033[0m\\r\\n\" if i == cursor else f\"{line}\\r\\n\")\n        write(\"\\r\\n\")\n        sys.stdout.flush()\n\n    fd = sys.stdin.fileno()\n    old = termios.tcgetattr(fd)\n    try:\n        tty.setraw(fd)\n        write(\"\\033[?25l\")\n        render(first=True)\n        while True:\n            ch = sys.stdin.read(1)\n            if ch in (\"q\", \"\\x03\"):\n                write(\"\\033[?25h\\033[0m\\r\\n\")\n                return None\n            elif ch in (\"\\r\", \"\\n\"):\n                break\n            elif ch == \"\\x1b\":\n                seq = sys.stdin.read(2)\n                if seq == \"[A\":\n                    cursor = (cursor - 1) % len(models)\n                elif seq == \"[B\":\n                    cursor = (cursor + 1) % len(models)\n            render()\n    finally:\n        termios.tcsetattr(fd, termios.TCSADRAIN, old)\n        write(f\"\\033[{total_lines}A\\033[J\")  # clear picker UI\n        write(\"\\033[?25h\\033[0m\")\n        sys.stdout.flush()\n\n    return models[cursor]\n\n\n# ---------------------------------------------------------------------------\n# Parallel requests\n# ---------------------------------------------------------------------------\n\nstatuses: list[str] = []\ntimes: list[str] = []\npreviews: list[str] = []\ntokens: list[str] = []\nfull_responses: list[dict | None] = []\ntotal_lines = 0\nstart_time: float = 0\nselected_model: str = \"\"\n\n\ndef render_progress(first: bool = False) -> None:\n    if not first:\n        
write(f\"\\033[{total_lines}A\")\n    write(\"\\033[J\")\n    elapsed = time.monotonic() - start_time if start_time else 0\n    done = sum(1 for s in statuses if s == \"done\")\n    write(\n        f\"\\033[1m{selected_model}\\033[0m  [{done}/{NUM_REQUESTS}]  {elapsed:.1f}s\\r\\n\\r\\n\"\n    )\n\n    for i in range(NUM_REQUESTS):\n        q = QUESTIONS[i % len(QUESTIONS)]\n        status = statuses[i]\n        if status == \"pending\":\n            color = \"\\033[33m\"  # yellow\n        elif status == \"running\":\n            color = \"\\033[36m\"  # cyan\n        elif status == \"done\":\n            color = \"\\033[32m\"  # green\n        else:\n            color = \"\\033[31m\"  # red\n        write(\n            f\"  {i:>2}  {color}{status:<8}\\033[0m  {times[i]:>6}  {tokens[i]:>5}tok  {q[:40]:<40}  {previews[i][:50]}\\r\\n\"\n        )\n\n    write(\"\\r\\n\")\n    sys.stdout.flush()\n\n\nasync def send_request(\n    session: aiohttp.ClientSession, i: int, lock: asyncio.Lock\n) -> None:\n    payload = {\n        \"model\": selected_model,\n        \"messages\": [{\"role\": \"user\", \"content\": QUESTIONS[i % len(QUESTIONS)]}],\n        \"max_tokens\": 1024,\n    }\n    statuses[i] = \"running\"\n    async with lock:\n        render_progress()\n    t0 = time.monotonic()\n    try:\n        async with session.post(\n            f\"{BASE_URL}/v1/chat/completions\", json=payload\n        ) as resp:\n            data = await resp.json()\n            elapsed = time.monotonic() - t0\n            full_responses[i] = data\n            times[i] = f\"{elapsed:.1f}s\"\n            if resp.status == 200:\n                choice = data[\"choices\"][0]\n                msg = choice[\"message\"]\n                content = msg.get(\"content\", \"\")\n                previews[i] = content[:50].replace(\"\\n\", \" \") or \"(empty)\"\n                if \"usage\" in data:\n                    tokens[i] = str(data[\"usage\"].get(\"total_tokens\", \"\"))\n                statuses[i] = \"done\"\n            else:\n                statuses[i] = f\"err:{resp.status}\"\n                previews[i] = str(data.get(\"error\", {}).get(\"message\", \"\"))[:50]\n    except Exception as e:\n        elapsed = time.monotonic() - t0\n        times[i] = f\"{elapsed:.1f}s\"\n        statuses[i] = \"error\"\n        previews[i] = str(e)[:50]\n    async with lock:\n        render_progress()\n\n\nasync def run_requests(print_stdout: bool = False) -> None:\n    global start_time, total_lines, statuses, times, previews, tokens, full_responses\n\n    statuses = [\"pending\"] * NUM_REQUESTS\n    times = [\"-\"] * NUM_REQUESTS\n    previews = [\"-\"] * NUM_REQUESTS\n    tokens = [\"-\"] * NUM_REQUESTS\n    full_responses = [None] * NUM_REQUESTS\n    total_lines = NUM_REQUESTS + 4\n\n    write(\"\\033[?25l\")  # hide cursor\n    start_time = time.monotonic()\n    render_progress(first=True)\n    lock = asyncio.Lock()\n    try:\n        async with aiohttp.ClientSession() as session:\n            tasks = [send_request(session, i, lock) for i in range(NUM_REQUESTS)]\n            await asyncio.gather(*tasks)\n        total = time.monotonic() - start_time\n        write(\n            f\"\\033[1m=== All {NUM_REQUESTS} requests done in {total:.1f}s ===\\033[0m\\r\\n\\r\\n\"\n        )\n\n        if print_stdout:\n            for i in range(NUM_REQUESTS):\n                data = full_responses[i]\n                if not data or \"choices\" not in data:\n                    continue\n                choice = data[\"choices\"][0]\n      
          msg = choice[\"message\"]\n                q = QUESTIONS[i % len(QUESTIONS)]\n                write(f\"\\033[1m--- #{i}: {q} ---\\033[0m\\r\\n\")\n                if msg.get(\"reasoning_content\"):\n                    write(f\"\\033[2m[Thinking]: {msg['reasoning_content']}\\033[0m\\r\\n\")\n                write(f\"{msg.get('content', '')}\\r\\n\")\n                if \"usage\" in data:\n                    u = data[\"usage\"]\n                    write(\n                        f\"\\033[2m[Usage: prompt={u.get('prompt_tokens')}, \"\n                        f\"completion={u.get('completion_tokens')}, \"\n                        f\"total={u.get('total_tokens')}]\\033[0m\\r\\n\"\n                    )\n                write(\"\\r\\n\")\n    finally:\n        write(\"\\033[?25h\")  # show cursor\n        sys.stdout.flush()\n\n\ndef main() -> None:\n    global selected_model, BASE_URL\n    parser = argparse.ArgumentParser(\n        description=\"Send parallel requests to an exo cluster\"\n    )\n    parser.add_argument(\n        \"--host\", required=True, help=\"Hostname of the exo node (e.g. s1)\"\n    )\n    parser.add_argument(\"--port\", type=int, default=52415, help=\"Port (default: 52415)\")\n    parser.add_argument(\n        \"--stdout\", action=\"store_true\", help=\"Print full responses after completion\"\n    )\n    args = parser.parse_args()\n    BASE_URL = f\"http://{args.host}:{args.port}\"\n    model = pick_model()\n    if not model:\n        return\n    selected_model = model\n    asyncio.run(run_requests(print_stdout=args.stdout))\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
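  {
    "path": "bench/examples/state_shape_sketch.py",
    "content": "# Illustrative sketch, NOT part of the original repo: the subset of the exo\n# /state response that fetch_models (bench/parallel_requests.py) and\n# bench/harness.py actually read. The key names are taken from that parsing\n# code; every placeholder value here is made up.\nSTATE_SHAPE_EXAMPLE = {\n    \"instances\": {\n        \"<instance-id>\": {\n            # A single variant tag wraps the payload; see unwrap_instance in\n            # bench/harness.py.\n            \"<VariantTag>\": {\n                \"instanceId\": \"<instance-id>\",\n                \"shardAssignments\": {\n                    \"modelId\": \"<hugging-face-model-id>\",\n                    \"nodeToRunner\": {\"<node-id>\": \"<runner-id>\"},\n                    \"runnerToShard\": {\"<runner-id>\": \"<shard-metadata>\"},\n                },\n            },\n        },\n    },\n    # Runner status is likewise a tagged dict: RunnerReady, or RunnerFailed\n    # carrying an errorMessage.\n    \"runners\": {\"<runner-id>\": {\"RunnerReady\": {}}},\n    # Per-node lists of download progress entries such as DownloadCompleted\n    # and DownloadFailed.\n    \"downloads\": {\"<node-id>\": []},\n    \"nodeDisk\": {\"<node-id>\": {\"available\": {\"inBytes\": 0}}},\n}\n"
  },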
  {
    "path": "bench/pyproject.toml",
    "content": "[project]\nname = \"exo-bench\"\nversion = \"0.1.0\"\ndescription = \"Benchmarking tool for exo distributed inference\"\nrequires-python = \">=3.13\"\ndependencies = [\n  \"httpx>=0.27.0\",\n  \"loguru>=0.7.3\",\n  \"transformers>=5.0.0\",\n  \"huggingface-hub>=0.33.4\",\n  \"tiktoken>=0.12.0\",\n  \"jinja2>=3.1.0\",\n  \"protobuf>=5.29.0\",\n  \"datasets>=2.0.0\",\n  \"math-verify>=0.7.0\",\n  \"lm-eval[api,math]>=0.4.0\",\n  \"human-eval>=1.0.3\",\n  \"numpy>=1.24.0\",\n]\n\n[build-system]\nrequires = [\"hatchling\"]\nbuild-backend = \"hatchling.build\"\n"
  },
  {
    "path": "bench/scenarios.toml",
    "content": "# Tool definitions — each becomes an OpenAI function tool.\n# All scenarios get all tools unless they specify a `tools` list.\n\n[tools.get_current_weather]\ndescription = \"Get the current weather in a given location\"\nrequired = [\"location\"]\n\n[tools.get_current_weather.properties.location]\ntype = \"string\"\ndescription = \"City and state, e.g. San Francisco, CA\"\n\n[tools.get_current_weather.properties.unit]\ntype = \"string\"\nenum = [\"celsius\", \"fahrenheit\"]\ndescription = \"Temperature unit\"\n\n[tools.calculate]\ndescription = \"Evaluate a mathematical expression and return the numeric result\"\nrequired = [\"expression\"]\n\n[tools.calculate.properties.expression]\ntype = \"string\"\ndescription = \"The math expression to evaluate, e.g. '2 + 3 * 4'\"\n\n[tools.search_products]\ndescription = \"Search for products in a catalog by query, category, and price\"\nrequired = [\"query\"]\n\n[tools.search_products.properties.query]\ntype = \"string\"\ndescription = \"Search query string\"\n\n[tools.search_products.properties.category]\ntype = \"string\"\nenum = [\"electronics\", \"clothing\", \"food\", \"books\"]\ndescription = \"Product category to filter by\"\n\n[tools.search_products.properties.max_price]\ntype = \"number\"\ndescription = \"Maximum price in USD\"\n\n[tools.create_todos]\ndescription = \"Create a structured todo list\"\nrequired = [\"todos\"]\n\n[tools.create_todos.properties.todos]\ntype = \"array\"\ndescription = \"List of todo items\"\n\n[tools.create_todos.properties.todos.items]\ntype = \"object\"\nrequired = [\"content\", \"status\", \"priority\"]\n\n[tools.create_todos.properties.todos.items.properties.content]\ntype = \"string\"\ndescription = \"The todo item text\"\n\n[tools.create_todos.properties.todos.items.properties.status]\ntype = \"string\"\ndescription = \"Status: pending, in_progress, or completed\"\n\n[tools.create_todos.properties.todos.items.properties.priority]\ntype = \"string\"\ndescription = \"Priority: low, normal, or high\"\n\n# -- Should call a tool --\n\n[[scenarios]]\nname = \"weather_simple\"\ndescription = \"Basic weather query -> get_current_weather\"\nexpect_tool_call = true\nexpected_function = \"get_current_weather\"\nrequired_arg_keys = [\"location\"]\n\n[[scenarios.messages]]\nrole = \"user\"\ncontent = \"What's the weather like in Tokyo right now?\"\n\n[[scenarios]]\nname = \"calculator_simple\"\ndescription = \"Math question -> calculate\"\nexpect_tool_call = true\nexpected_function = \"calculate\"\nrequired_arg_keys = [\"expression\"]\n\n[[scenarios.messages]]\nrole = \"user\"\ncontent = \"Use the calculator to compute 3847 * 926 + 17293\"\n\n[[scenarios]]\nname = \"search_with_filters\"\ndescription = \"Product search with category and price filter\"\nexpect_tool_call = true\nexpected_function = \"search_products\"\nrequired_arg_keys = [\"query\"]\n\n[[scenarios.messages]]\nrole = \"user\"\ncontent = \"Find me electronics under $50\"\n\n# -- Multi-turn: tool call then follow-up --\n\n[[scenarios]]\nname = \"weather_multi_turn\"\ndescription = \"Weather query -> tool result -> natural language summary\"\nexpect_tool_call = true\nexpected_function = \"get_current_weather\"\nrequired_arg_keys = [\"location\"]\n\n[scenarios.tool_result]\ntemperature = \"18C\"\ncondition = \"partly cloudy\"\nhumidity = \"65%\"\nwind = \"12 km/h NW\"\n\n[[scenarios.messages]]\nrole = \"user\"\ncontent = \"What's the weather in Paris?\"\n\n[[scenarios]]\nname = \"calculator_multi_turn\"\ndescription = \"Math query -> tool 
result -> model reports the answer\"\nexpect_tool_call = true\nexpected_function = \"calculate\"\nrequired_arg_keys = [\"expression\"]\n\n[scenarios.tool_result]\nresult = 491682\n\n[[scenarios.messages]]\nrole = \"user\"\ncontent = \"Use the calculator to compute 1847 * 263 + 5921\"\n\n[[scenarios]]\nname = \"search_multi_turn\"\ndescription = \"Search query -> tool result -> model summarizes products\"\nexpect_tool_call = true\nexpected_function = \"search_products\"\nrequired_arg_keys = [\"query\"]\n\n[[scenarios.tool_result.results]]\nname = \"Hands-On Machine Learning\"\nprice = 45.99\nrating = 4.8\n\n[[scenarios.tool_result.results]]\nname = \"Deep Learning with Python\"\nprice = 39.99\nrating = 4.6\n\n[[scenarios.messages]]\nrole = \"user\"\ncontent = \"Search for books about machine learning\"\n\n# -- Sequential tool calls --\n\n[[scenarios]]\nname = \"chained_tool_calls_same\"\ndescription = \"Thinking + weather(Tokyo) -> result -> model must call weather(London)\"\nexpect_tool_call = true\nexpected_function = \"get_current_weather\"\nrequired_arg_keys = [\"location\"]\n\n[[scenarios.messages]]\nrole = \"user\"\ncontent = \"Compare the weather in Tokyo and London.\"\n\n[[scenarios.messages]]\nrole = \"assistant\"\ncontent = \"I'll check both cities. Let me start with Tokyo.\"\n\n[[scenarios.messages.tool_calls]]\nid = \"call_1\"\nname = \"get_current_weather\"\narguments = { location = \"Tokyo\" }\n\n[[scenarios.messages]]\nrole = \"tool\"\ntool_call_id = \"call_1\"\ncontent = '{\"temperature\": \"25C\", \"condition\": \"sunny\"}'\n\n[[scenarios]]\nname = \"chained_tool_calls_different\"\ndescription = \"Thinking + weather(Berlin) -> result -> model must call calculator\"\nexpect_tool_call = true\nexpected_function = \"calculate\"\nrequired_arg_keys = [\"expression\"]\n\n[[scenarios.messages]]\nrole = \"user\"\ncontent = \"What's the weather in Berlin, and also use the calculator to compute 4819 * 37 + 291.\"\n\n[[scenarios.messages]]\nrole = \"assistant\"\ncontent = \"I'll handle both. Let me check Berlin's weather first.\"\n\n[[scenarios.messages.tool_calls]]\nid = \"call_2\"\nname = \"get_current_weather\"\narguments = { location = \"Berlin\" }\n\n[[scenarios.messages]]\nrole = \"tool\"\ntool_call_id = \"call_2\"\ncontent = '{\"temperature\": \"12C\", \"condition\": \"rainy\"}'\n\n[[scenarios]]\nname = \"chained_tool_calls_three\"\ndescription = \"Two prior thinking+tool calls -> results -> model must make a third\"\nexpect_tool_call = true\nexpected_function = \"get_current_weather\"\nrequired_arg_keys = [\"location\"]\n\n[[scenarios.messages]]\nrole = \"user\"\ncontent = \"Compare weather in Tokyo, Paris, and London.\"\n\n[[scenarios.messages]]\nrole = \"assistant\"\ncontent = \"I'll check all three cities. Starting with Tokyo.\"\n\n[[scenarios.messages.tool_calls]]\nid = \"call_3\"\nname = \"get_current_weather\"\narguments = { location = \"Tokyo\" }\n\n[[scenarios.messages]]\nrole = \"tool\"\ntool_call_id = \"call_3\"\ncontent = '{\"temperature\": \"25C\", \"condition\": \"sunny\"}'\n\n[[scenarios.messages]]\nrole = \"assistant\"\ncontent = \"Got Tokyo. 
Now checking Paris.\"\n\n[[scenarios.messages.tool_calls]]\nid = \"call_4\"\nname = \"get_current_weather\"\narguments = { location = \"Paris\" }\n\n[[scenarios.messages]]\nrole = \"tool\"\ntool_call_id = \"call_4\"\ncontent = '{\"temperature\": \"18C\", \"condition\": \"cloudy\"}'\n\n# -- Nested object schema (regression for lossy chat template rendering) --\n\n[[scenarios]]\nname = \"nested_schema_tool_call\"\ndescription = \"Tool call with nested object array schema -> create_todos\"\nexpect_tool_call = true\nexpected_function = \"create_todos\"\nrequired_arg_keys = [\"todos\"]\nnested_array_key = \"todos\"\nrequired_item_keys = [\"content\", \"status\", \"priority\"]\ntools = [\"create_todos\"]\n\n[[scenarios.messages]]\nrole = \"user\"\ncontent = \"Create a todo list with 3 items to learn Python\"\n\n# -- Tool name integrity (regression for harmony token leaking into name) --\n\n[tools.glob]\ndescription = \"Search for files matching a glob pattern in the codebase\"\nrequired = [\"pattern\"]\n\n[tools.glob.properties.pattern]\ntype = \"string\"\ndescription = \"The glob pattern to match files against, e.g. '**/*.py'\"\n\n[tools.glob.properties.path]\ntype = \"string\"\ndescription = \"The directory to search in\"\n\n[[scenarios]]\nname = \"tool_name_integrity\"\ndescription = \"Tool name must not contain harmony tokens like <|channel|>\"\nexpect_tool_call = true\nexpected_function = \"glob\"\nrequired_arg_keys = [\"pattern\"]\ntools = [\"glob\"]\n\n[[scenarios.messages]]\nrole = \"user\"\ncontent = \"Find all Python files in the src directory\"\n\n# -- Should NOT call a tool --\n\n[[scenarios]]\nname = \"no_tool_joke\"\ndescription = \"Joke request should NOT trigger any tool\"\nexpect_tool_call = false\n\n[[scenarios.messages]]\nrole = \"user\"\ncontent = \"Tell me a funny joke about cats.\"\n\n[[scenarios]]\nname = \"no_tool_factual\"\ndescription = \"Factual question answerable from training data\"\nexpect_tool_call = false\n\n[[scenarios.messages]]\nrole = \"user\"\ncontent = \"What is the capital of Japan?\"\n"
  },
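  {
    "path": "bench/examples/scenarios_to_openai_tools.py",
    "content": "# Illustrative sketch, NOT part of the original repo: how each [tools.*] entry\n# in bench/scenarios.toml could be rendered as an OpenAI-style function tool,\n# per the comment at the top of that file. The benchmark's real loader is not\n# shown here, so treat this exact mapping as an assumption.\nimport tomllib\n\nwith open(\"bench/scenarios.toml\", \"rb\") as f:\n    config = tomllib.load(f)\n\nopenai_tools = [\n    {\n        \"type\": \"function\",\n        \"function\": {\n            \"name\": name,\n            \"description\": tool[\"description\"],\n            # Property tables pass through as JSON Schema fragments.\n            \"parameters\": {\n                \"type\": \"object\",\n                \"properties\": tool.get(\"properties\", {}),\n                \"required\": tool.get(\"required\", []),\n            },\n        },\n    }\n    for name, tool in config[\"tools\"].items()\n]\n"
  },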
  {
    "path": "bench/single-m3-ultra.toml",
    "content": "# Single-node M3 Ultra benchmarks\n#\n# Shared constraints applied to ALL benchmarks in this file.\nconstraints = [\n  \"All(MacOsBuild(=25D125))\",\n  \"Hosts(=1)\",\n  \"All(Chip(m3_ultra))\",\n  \"All(GpuCores(=80))\",\n]\n\n[topology]\ntype = \"none\"\n\n# Default args merged into each benchmark's args (benchmark-level args win).\n[defaults]\npp = [512, 2048, 8192, 16384]\ntg = 128\n\n[[benchmark]]\nmodel = \"mlx-community/Meta-Llama-3.1-70B-Instruct-4bit\"\nextra_constraints = [\"All(Memory(>=96GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/gpt-oss-120b-MXFP4-Q8\"\nextra_constraints = [\"All(Memory(>=96GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/GLM-4.7-Flash-8bit\"\nextra_constraints = [\"All(Memory(>=96GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/Qwen3-Coder-Next-6bit\"\nextra_constraints = [\"All(Memory(>=96GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/Qwen3-30B-A3B-8bit\"\nextra_constraints = [\"All(Memory(>=96GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/Qwen3-0.6B-4bit\"\nextra_constraints = [\"All(Memory(>=96GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/Qwen3-0.6B-8bit\"\nextra_constraints = [\"All(Memory(>=96GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/Llama-3.2-1B-Instruct-4bit\"\nextra_constraints = [\"All(Memory(>=96GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/Llama-3.2-3B-Instruct-4bit\"\nextra_constraints = [\"All(Memory(>=96GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/Llama-3.2-3B-Instruct-8bit\"\nextra_constraints = [\"All(Memory(>=96GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/Meta-Llama-3.1-8B-Instruct-4bit\"\nextra_constraints = [\"All(Memory(>=96GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/Meta-Llama-3.1-8B-Instruct-8bit\"\nextra_constraints = [\"All(Memory(>=96GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/Meta-Llama-3.1-8B-Instruct-bf16\"\nextra_constraints = [\"All(Memory(>=96GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/gpt-oss-20b-MXFP4-Q8\"\nextra_constraints = [\"All(Memory(>=96GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/Qwen3-30B-A3B-4bit\"\nextra_constraints = [\"All(Memory(>=96GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/GLM-4.7-Flash-4bit\"\nextra_constraints = [\"All(Memory(>=96GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/GLM-4.7-Flash-5bit\"\nextra_constraints = [\"All(Memory(>=96GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/GLM-4.7-Flash-6bit\"\nextra_constraints = [\"All(Memory(>=96GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/Llama-3.3-70B-Instruct-4bit\"\nextra_constraints = [\"All(Memory(>=96GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/Qwen3-Coder-Next-4bit\"\nextra_constraints = [\"All(Memory(>=96GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/Qwen3-Coder-Next-5bit\"\nextra_constraints = [\"All(Memory(>=96GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/Qwen3-Coder-Next-8bit\"\nextra_constraints = [\"All(Memory(>=96GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit\"\nextra_constraints = [\"All(Memory(>=96GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit\"\nextra_constraints = [\"All(Memory(>=96GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit\"\nextra_constraints = [\"All(Memory(>=96GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit\"\nextra_constraints = [\"All(Memory(>=96GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/Llama-3.3-70B-Instruct-8bit\"\nextra_constraints = 
[\"All(Memory(>=256GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/llama-3.3-70b-instruct-fp16\"\nextra_constraints = [\"All(Memory(>=256GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/GLM-4.5-Air-8bit\"\nextra_constraints = [\"All(Memory(>=256GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/GLM-4.5-Air-bf16\"\nextra_constraints = [\"All(Memory(>=256GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/GLM-4.7-4bit\"\nextra_constraints = [\"All(Memory(>=256GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/MiniMax-M2.1-3bit\"\nextra_constraints = [\"All(Memory(>=256GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/MiniMax-M2.1-8bit\"\nextra_constraints = [\"All(Memory(>=256GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit\"\nextra_constraints = [\"All(Memory(>=256GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/Qwen3-Coder-Next-bf16\"\nextra_constraints = [\"All(Memory(>=256GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/Step-3.5-Flash-4bit\"\nextra_constraints = [\"All(Memory(>=256GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/Step-3.5-Flash-6bit\"\nextra_constraints = [\"All(Memory(>=256GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/Step-3.5-Flash-8Bit\"\nextra_constraints = [\"All(Memory(>=256GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/DeepSeek-V3.1-4bit\"\nextra_constraints = [\"All(Memory(>=512GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/GLM-4.7-6bit\"\nextra_constraints = [\"All(Memory(>=512GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/GLM-4.7-8bit-gs32\"\nextra_constraints = [\"All(Memory(>=512GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit\"\nextra_constraints = [\"All(Memory(>=512GiB))\"]\n\n[[benchmark]]\nmodel = \"mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit\"\nextra_constraints = [\"All(Memory(>=512GiB))\"]\n"
  },
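  {
    "path": "bench/examples/merge_benchmark_defaults.py",
    "content": "# Illustrative sketch, NOT part of the original repo: the \"[defaults] merged\n# into each benchmark's args (benchmark-level args win)\" rule described at the\n# top of bench/single-m3-ultra.toml, assuming a plain dict merge is what is\n# intended. Loader and variable names are hypothetical.\nimport tomllib\n\nwith open(\"bench/single-m3-ultra.toml\", \"rb\") as f:\n    config = tomllib.load(f)\n\ndefaults = config.get(\"defaults\", {})\nfor benchmark in config[\"benchmark\"]:\n    overrides = {\n        key: value\n        for key, value in benchmark.items()\n        if key not in (\"model\", \"extra_constraints\")\n    }\n    # In a dict-literal merge the right-hand operand wins, so benchmark-level\n    # args override the shared defaults.\n    merged_args = {**defaults, **overrides}\n    print(benchmark[\"model\"], merged_args)\n"
  },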
  {
    "path": "bench/src/exo_bench/__init__.py",
    "content": ""
  },
  {
    "path": "bench/vendor/__init__.py",
    "content": ""
  },
  {
    "path": "bench/vendor/lcb_testing_util.py",
    "content": "# type: ignore\n# Vendored from LiveCodeBench (https://github.com/LiveCodeBench/LiveCodeBench)\n# File: lcb_runner/evaluation/testing_util.py\n# License: MIT\n# Vendored 2026-03-07 — do not modify without updating from upstream.\n\nimport ast\nimport faulthandler\nimport json\nimport platform\n\n# to run the solution files we're using a timing based approach\nimport signal\nimport sys\nimport time\n\n# used for debugging to time steps\nfrom datetime import datetime\nfrom decimal import Decimal\nfrom enum import Enum\nfrom io import StringIO\n\n# from pyext import RuntimeModule\nfrom types import ModuleType\n\n# used for testing the code that reads from input\nfrom unittest.mock import mock_open, patch\n\nimport_string = \"from string import *\\nfrom re import *\\nfrom datetime import *\\nfrom collections import *\\nfrom heapq import *\\nfrom bisect import *\\nfrom copy import *\\nfrom math import *\\nfrom random import *\\nfrom statistics import *\\nfrom itertools import *\\nfrom functools import *\\nfrom operator import *\\nfrom io import *\\nfrom sys import *\\nfrom json import *\\nfrom builtins import *\\nfrom typing import *\\nimport string\\nimport re\\nimport datetime\\nimport collections\\nimport heapq\\nimport bisect\\nimport copy\\nimport math\\nimport random\\nimport statistics\\nimport itertools\\nimport functools\\nimport operator\\nimport io\\nimport sys\\nimport json\\nsys.setrecursionlimit(50000)\\n\"\n\n\ndef truncatefn(s, length=300):\n    if isinstance(s, str):\n        pass\n    else:\n        s = str(s)\n    if len(s) <= length:\n        return s\n\n    return s[: length // 2] + \"...(truncated) ...\" + s[-length // 2 :]\n\n\nclass CODE_TYPE(Enum):\n    call_based = 0\n    standard_input = 1\n\n\n# stuff for setting up signal timer\nclass TimeoutException(Exception):\n    pass\n\n\ndef timeout_handler(signum, frame):\n    print(\"timeout occured: alarm went off\")\n    raise TimeoutException\n\n\n# used to capture stdout as a list\n# from https://stackoverflow.com/a/16571630/6416660\n# alternative use redirect_stdout() from contextlib\nclass Capturing(list):\n    def __enter__(self):\n        self._stdout = sys.stdout\n        sys.stdout = self._stringio = StringIO()\n        # Make closing the StringIO a no-op\n        self._stringio.close = lambda x: 1\n        return self\n\n    def __exit__(self, *args):\n        self.append(self._stringio.getvalue())\n        del self._stringio  # free up some memory\n        sys.stdout = self._stdout\n\n\n# Custom mock for sys.stdin that supports buffer attribute\nclass MockStdinWithBuffer:\n    def __init__(self, inputs: str):\n        self.inputs = inputs\n        self._stringio = StringIO(inputs)\n        self.buffer = MockBuffer(inputs)\n\n    def read(self, *args):\n        return self.inputs\n\n    def readline(self, *args):\n        return self._stringio.readline(*args)\n\n    def readlines(self, *args):\n        return self.inputs.split(\"\\n\")\n\n    def __getattr__(self, name):\n        # Delegate other attributes to StringIO\n        return getattr(self._stringio, name)\n\n\nclass MockBuffer:\n    def __init__(self, inputs: str):\n        self.inputs = inputs.encode(\"utf-8\")  # Convert to bytes\n\n    def read(self, *args):\n        # Return as byte strings that can be split\n        return self.inputs\n\n    def readline(self, *args):\n        return self.inputs.split(b\"\\n\")[0] + b\"\\n\"\n\n\ndef clean_if_name(code: str) -> str:\n    try:\n        astree = ast.parse(code)\n        last_block = 
astree.body[-1]\n        if isinstance(last_block, ast.If):\n            condition = last_block.test\n            if ast.unparse(condition).strip() == \"__name__ == '__main__'\":\n                code = (\n                    ast.unparse(astree.body[:-1]) + \"\\n\" + ast.unparse(last_block.body)  # type: ignore\n                )\n    except:\n        pass\n\n    return code\n\n\ndef make_function(code: str) -> str:\n    try:\n        import_stmts = []\n        all_other_stmts = []\n        astree = ast.parse(code)\n        for stmt in astree.body:\n            if isinstance(stmt, (ast.Import, ast.ImportFrom)):\n                import_stmts.append(stmt)\n            else:\n                all_other_stmts.append(stmt)\n\n        function_ast = ast.FunctionDef(\n            name=\"wrapped_function\",\n            args=ast.arguments(\n                posonlyargs=[], args=[], kwonlyargs=[], kw_defaults=[], defaults=[]\n            ),\n            body=all_other_stmts,\n            decorator_list=[],\n            lineno=-1,\n        )\n        main_code = (\n            import_string\n            + \"\\n\"\n            + ast.unparse(import_stmts)\n            + \"\\n\"\n            + ast.unparse(function_ast)\n        )\n        return main_code\n    except Exception:\n        return code\n\n\ndef call_method(method, inputs):\n    if isinstance(inputs, list):\n        inputs = \"\\n\".join(inputs)\n\n    inputs_line_iterator = iter(inputs.split(\"\\n\"))\n\n    # Create custom stdin mock with buffer support\n    mock_stdin = MockStdinWithBuffer(inputs)\n\n    # sys.setrecursionlimit(10000)\n\n    # @patch('builtins.input', side_effect=inputs.split(\"\\n\"))\n    @patch(\"builtins.open\", mock_open(read_data=inputs))\n    @patch(\"sys.stdin\", mock_stdin)  # Use our custom mock instead of StringIO\n    @patch(\"sys.stdin.readline\", lambda *args: next(inputs_line_iterator))\n    @patch(\"sys.stdin.readlines\", lambda *args: inputs.split(\"\\n\"))\n    @patch(\"sys.stdin.read\", lambda *args: inputs)\n    # @patch('sys.stdout.write', print)\n    def _inner_call_method(_method):\n        try:\n            return _method()\n        except SystemExit:\n            pass\n        finally:\n            pass\n\n    return _inner_call_method(method)\n\n\ndef get_function(compiled_sol, fn_name: str):  # type: ignore\n    try:\n        assert hasattr(compiled_sol, fn_name)\n        return getattr(compiled_sol, fn_name)\n    except Exception:\n        return\n\n\ndef compile_code(code: str, timeout: int):\n    signal.alarm(timeout)\n    try:\n        tmp_sol = ModuleType(\"tmp_sol\", \"\")\n        exec(code, tmp_sol.__dict__)\n        if \"class Solution\" in code:\n            # leetcode wraps solutions in `Solution`\n            # this is a hack to check if it is leetcode solution or not\n            # currently livecodebench only supports LeetCode but\n            # else condition allows future extensibility to other platforms\n            compiled_sol = tmp_sol.Solution()\n        else:\n            # do nothing in the other case since function is accesible\n            compiled_sol = tmp_sol\n\n        assert compiled_sol is not None\n    finally:\n        signal.alarm(0)\n\n    return compiled_sol\n\n\ndef convert_line_to_decimals(line: str) -> tuple[bool, list[Decimal]]:\n    try:\n        decimal_line = [Decimal(elem) for elem in line.split()]\n    except:\n        return False, []\n    return True, decimal_line\n\n\ndef get_stripped_lines(val: str):\n    ## you don't want empty lines to add empty 
list after splitlines!\n    val = val.strip()\n\n    return [val_line.strip() for val_line in val.split(\"\\n\")]\n\n\ndef grade_call_based(\n    code: str, all_inputs: list, all_outputs: list, fn_name: str, timeout: int\n):\n    # call-based clean up logic\n    # need to wrap in try-catch logic after to catch the correct errors, but for now this is fine.\n    code = import_string + \"\\n\\n\" + code\n    compiled_sol = compile_code(code, timeout)\n\n    if compiled_sol is None:\n        return\n\n    method = get_function(compiled_sol, fn_name)\n\n    if method is None:\n        return\n\n    all_inputs = [\n        [json.loads(line) for line in inputs.split(\"\\n\")] for inputs in all_inputs\n    ]\n\n    all_outputs = [json.loads(output) for output in all_outputs]\n\n    total_execution = 0\n    all_results = []\n    for idx, (gt_inp, gt_out) in enumerate(zip(all_inputs, all_outputs)):\n        signal.alarm(timeout)\n        faulthandler.enable()\n        try:\n            # can lock here so time is useful\n            start = time.time()\n            prediction = method(*gt_inp)\n            total_execution += time.time() - start\n            signal.alarm(0)\n\n            # don't penalize model if it produces tuples instead of lists\n            # ground truth sequences are not tuples\n            if isinstance(prediction, tuple):\n                prediction = list(prediction)\n\n            tmp_result = prediction == gt_out\n\n            # handle floating point comparisons\n\n            all_results.append(tmp_result)\n\n            if not tmp_result:\n                return all_results, {\n                    \"output\": truncatefn(prediction),\n                    \"inputs\": truncatefn(gt_inp),\n                    \"expected\": truncatefn(gt_out),\n                    \"error_code\": -2,\n                    \"error_message\": \"Wrong Answer\",\n                }\n        except Exception as e:\n            signal.alarm(0)\n            if \"timeoutexception\" in repr(e).lower():\n                all_results.append(-3)\n                return all_results, {\n                    \"error\": repr(e),\n                    \"error_code\": -3,\n                    \"error_message\": \"Time Limit Exceeded\",\n                    \"inputs\": truncatefn(gt_inp),\n                    \"expected\": truncatefn(gt_out),\n                }\n            else:\n                all_results.append(-4)\n                return all_results, {\n                    \"error\": repr(e),\n                    \"error_code\": -4,\n                    \"error_message\": \"Runtime Error\",\n                    \"inputs\": truncatefn(gt_inp),\n                    \"expected\": truncatefn(gt_out),\n                }\n\n        finally:\n            signal.alarm(0)\n            faulthandler.disable()\n\n    return all_results, {\"execution time\": total_execution}\n\n\ndef grade_stdio(\n    code: str,\n    all_inputs: list,\n    all_outputs: list,\n    timeout: int,\n):\n    ## runtime doesn't interact well with __name__ == '__main__'\n    code = clean_if_name(code)\n\n    ## we wrap the given code inside another function\n    code = make_function(code)\n\n    compiled_sol = compile_code(code, timeout)\n    if compiled_sol is None:\n        return\n\n    method = get_function(compiled_sol, \"wrapped_function\")\n\n    if method is None:\n        return\n\n    all_results = []\n    total_execution_time = 0\n    for idx, (gt_inp, gt_out) in enumerate(zip(all_inputs, all_outputs)):\n        signal.alarm(timeout)\n  
        faulthandler.enable()\n\n        with Capturing() as captured_output:\n            try:\n                start = time.time()\n                call_method(method, gt_inp)\n                total_execution_time += time.time() - start\n                # reset the alarm\n                signal.alarm(0)\n            except Exception as e:\n                signal.alarm(0)\n                if \"timeoutexception\" in repr(e).lower():\n                    all_results.append(-3)\n                    return all_results, {\n                        \"error\": repr(e),\n                        \"error_code\": -3,\n                        \"error_message\": \"Time Limit Exceeded\",\n                        \"inputs\": truncatefn(gt_inp),\n                        \"expected\": truncatefn(gt_out),\n                    }\n                else:\n                    all_results.append(-4)\n                    return all_results, {\n                        \"error\": repr(e),\n                        \"error_code\": -4,\n                        \"error_message\": \"Runtime Error\",\n                        \"inputs\": truncatefn(gt_inp),\n                        \"expected\": truncatefn(gt_out),\n                    }\n\n            finally:\n                signal.alarm(0)\n                faulthandler.disable()\n\n        # an empty capture (the program printed nothing) should grade as a\n        # Wrong Answer below rather than crash with an IndexError\n        prediction = captured_output[0] if captured_output else \"\"\n\n        stripped_prediction_lines = get_stripped_lines(prediction)\n        stripped_gt_out_lines = get_stripped_lines(gt_out)\n\n        ## Wrong Answer can be reported from several places below,\n        ## so build the metadata payload once up front\n        WA_send_args = {\n            \"output\": truncatefn(prediction),\n            \"inputs\": truncatefn(gt_inp),\n            \"expected\": truncatefn(gt_out),\n            \"error_code\": -2,\n        }\n\n        if len(stripped_prediction_lines) != len(stripped_gt_out_lines):\n            all_results.append(-2)\n            WA_send_args[\"error_message\"] = \"Wrong answer: mismatched output length\"\n            return all_results, WA_send_args\n\n        for output_line_idx, (\n            stripped_prediction_line,\n            stripped_gt_out_line,\n        ) in enumerate(zip(stripped_prediction_lines, stripped_gt_out_lines)):\n            WA_send_args[\"error_message\"] = (\n                f\"Wrong answer at {output_line_idx=}: {truncatefn(stripped_prediction_line)} != {truncatefn(stripped_gt_out_line)}\"\n            )\n\n            ## CASE 1: exact match\n            if stripped_prediction_line == stripped_gt_out_line:\n                continue\n\n            ## CASE 2: element-wise comparison\n            ## if there are floating elements\n            ## use `decimal` library for good floating point comparison\n            ## otherwise gotcha: np.isclose(50000000000000000, 50000000000000001) = True\n            ## whereas Decimal(\"0.50\") == Decimal(\"0.5\") still compares equal\n            ## note that we should always be able to convert to decimals\n\n            success, decimal_prediction_line = convert_line_to_decimals(\n                stripped_prediction_line\n            )\n            if not success:\n                all_results.append(-2)\n                return all_results, WA_send_args\n            success, decimal_gtout_line = convert_line_to_decimals(stripped_gt_out_line)\n            if not success:\n                all_results.append(-2)\n                return all_results, WA_send_args\n\n            if decimal_prediction_line == decimal_gtout_line:\n                continue\n\n            all_results.append(-2)\n            return all_results, WA_send_args
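\n\n        # every output line matched, exactly or numerically -> this test case passes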
\n        all_results.append(True)\n\n    return all_results, {\"execution time\": total_execution_time}\n\n\ndef run_test(sample, test=None, debug=False, timeout=6):\n    \"\"\"\n    If test (the generated code) is not None, it is run against the sample's\n    tests; otherwise the parsed input/output pairs are returned as-is.\n    Grader exceptions are caught below and reported as a -4\n    \"Error during testing\" result.\n    \"\"\"\n    signal.signal(signal.SIGALRM, timeout_handler)\n\n    # Disable functionality that can make destructive changes to the test.\n    # NOTE: no memory cap is applied here; pass maximum_memory_bytes to add one.\n    reliability_guard()\n\n    if debug:\n        print(f\"start = {datetime.now().time()}\")\n\n    # a malformed input_output payload is a caller error; let ValueError propagate\n    in_outs = json.loads(sample[\"input_output\"])\n\n    if in_outs:\n        if in_outs.get(\"fn_name\") is None:\n            which_type = CODE_TYPE.standard_input  # Standard input\n            method_name = None\n\n        else:\n            which_type = CODE_TYPE.call_based  # Call-based\n            method_name = in_outs[\"fn_name\"]\n\n    if debug:\n        print(f\"loaded input_output = {datetime.now().time()}\")\n\n    assert test is not None, \"should not happen: test code is none\"\n\n    if debug:\n        print(f\"loading test code = {datetime.now().time()}\")\n\n    if which_type == CODE_TYPE.call_based:\n        signal.alarm(timeout)\n        try:\n            results, metadata = grade_call_based(\n                code=test,\n                all_inputs=in_outs[\"inputs\"],\n                all_outputs=in_outs[\"outputs\"],\n                fn_name=method_name,\n                timeout=timeout,\n            )\n            return results, metadata\n        except Exception as e:\n            return [-4], {\n                \"error_code\": -4,\n                \"error_message\": f\"Error during testing: {e}\",\n            }\n        finally:\n            signal.alarm(0)\n    elif which_type == CODE_TYPE.standard_input:\n        # grade_stdio strips any if __name__ == \"__main__\": guard before wrapping\n        signal.alarm(timeout)\n        try:\n            results, metadata = grade_stdio(\n                code=test,\n                all_inputs=in_outs[\"inputs\"],\n                all_outputs=in_outs[\"outputs\"],\n                timeout=timeout,\n            )\n            return results, metadata\n        except Exception as e:\n            return [-4], {\n                \"error_code\": -4,\n                \"error_message\": f\"Error during testing: {e}\",\n            }\n        finally:\n            signal.alarm(0)\n\n\ndef reliability_guard(maximum_memory_bytes=None):\n    \"\"\"\n    This disables various destructive functions and prevents the generated code\n    from interfering with the test (e.g. fork bomb, killing other processes,\n    removing filesystem files, etc.)\n    WARNING\n    This function is NOT a security sandbox. Untrusted code, including\n    model-generated code, should not be blindly executed outside of one.
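 It only\n    neuters a set of destructive builtins/os functions and, when\n    maximum_memory_bytes is given, applies an address-space limit.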
 See the\n    Codex paper for more information about OpenAI's code sandbox, and proceed\n    with caution.\n    \"\"\"\n\n    if maximum_memory_bytes is not None:\n        import resource\n\n        resource.setrlimit(\n            resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)\n        )\n        resource.setrlimit(\n            resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)\n        )\n        if platform.uname().system != \"Darwin\":\n            resource.setrlimit(\n                resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)\n            )\n\n    faulthandler.disable()\n\n    import builtins\n\n    # builtins.exit = None\n    builtins.quit = None\n\n    import os\n\n    os.environ[\"OMP_NUM_THREADS\"] = \"1\"\n\n    os.kill = None\n    os.system = None\n    os.putenv = None\n    os.remove = None\n    os.removedirs = None\n    os.rmdir = None\n    os.fchdir = None\n    os.setuid = None\n    os.fork = None\n    os.forkpty = None\n    os.killpg = None\n    os.rename = None\n    os.renames = None\n    os.truncate = None\n    os.replace = None\n    os.unlink = None\n    os.fchmod = None\n    os.fchown = None\n    os.chmod = None\n    os.chown = None\n    os.chroot = None\n    os.lchflags = None\n    os.lchmod = None\n    os.lchown = None\n    os.getcwd = None\n    os.chdir = None\n\n    import shutil\n\n    shutil.rmtree = None\n    shutil.move = None\n    shutil.chown = None\n\n    import subprocess\n\n    subprocess.Popen = None\n\n    __builtins__[\"help\"] = None\n\n    import sys\n\n    sys.modules[\"ipdb\"] = None\n    sys.modules[\"joblib\"] = None\n    sys.modules[\"resource\"] = None\n    sys.modules[\"psutil\"] = None\n    sys.modules[\"tkinter\"] = None\n"
  },
  {
    "path": "dashboard/dashboard.nix",
    "content": "{ lib\n, config\n, dream2nix\n, ...\n}:\nlet\n  # Read and parse the lock file\n  rawLockFile = builtins.fromJSON (builtins.readFile \"${config.deps.dashboardSrc}/package-lock.json\");\n\n  # For packages with bundleDependencies, filter out deps that are bundled\n  # (bundled deps are inside the tarball, not separate lockfile entries)\n  fixedPackages = lib.mapAttrs\n    (path: entry:\n      if entry ? bundleDependencies && entry.bundleDependencies != [ ]\n      then entry // {\n        dependencies = lib.filterAttrs\n          (name: _: !(lib.elem name entry.bundleDependencies))\n          (entry.dependencies or { });\n      }\n      else entry\n    )\n    (rawLockFile.packages or { });\n\n  fixedLockFile = rawLockFile // { packages = fixedPackages; };\nin\n{\n  imports = [\n    dream2nix.modules.dream2nix.nodejs-package-lock-v3\n    dream2nix.modules.dream2nix.nodejs-granular-v3\n  ];\n\n  name = \"exo-dashboard\";\n  version = \"1.0.0\";\n\n  mkDerivation = {\n    src = config.deps.dashboardSrc;\n\n    buildPhase = ''\n      runHook preBuild\n      npm run build\n      runHook postBuild\n    '';\n\n    installPhase = ''\n      runHook preInstall\n      cp -r build $out/build\n      runHook postInstall\n    '';\n  };\n\n  deps = { nixpkgs, ... }: {\n    inherit (nixpkgs) stdenv;\n    dashboardSrc = null; # Injected by parts.nix\n  };\n\n  nodejs-package-lock-v3 = {\n    # Don't use packageLockFile - provide the fixed lock content directly\n    packageLock = fixedLockFile;\n  };\n}\n"
  },
  {
    "path": "dashboard/package.json",
    "content": "{\n\t\"name\": \"exo-dashboard\",\n\t\"private\": true,\n\t\"version\": \"1.0.0\",\n\t\"type\": \"module\",\n\t\"scripts\": {\n\t\t\"dev\": \"vite dev\",\n\t\t\"build\": \"vite build\",\n\t\t\"preview\": \"vite preview\",\n\t\t\"prepare\": \"svelte-kit sync || echo ''\",\n\t\t\"check\": \"svelte-kit sync && svelte-check --tsconfig ./tsconfig.json\"\n\t},\n\t\"devDependencies\": {\n\t\t\"prettier\": \"^3.4.2\",\n\t\t\"prettier-plugin-svelte\": \"^3.3.3\",\n\t\t\"@sveltejs/adapter-static\": \"^3.0.10\",\n\t\t\"@sveltejs/kit\": \"^2.48.4\",\n\t\t\"@sveltejs/vite-plugin-svelte\": \"^5.0.0\",\n\t\t\"@tailwindcss/vite\": \"^4.0.0\",\n\t\t\"@types/d3\": \"^7.4.3\",\n\t\t\"@types/node\": \"^22\",\n\t\t\"d3\": \"^7.9.0\",\n\t\t\"svelte\": \"^5.0.0\",\n\t\t\"svelte-check\": \"^4.0.0\",\n\t\t\"tailwindcss\": \"^4.0.0\",\n\t\t\"tw-animate-css\": \"^1.3.5\",\n\t\t\"typescript\": \"^5.0.0\",\n\t\t\"vite\": \"^6.0.0\"\n\t},\n\t\"dependencies\": {\n\t\t\"highlight.js\": \"^11.11.1\",\n\t\t\"katex\": \"^0.16.27\",\n\t\t\"marked\": \"^17.0.1\",\n\t\t\"mode-watcher\": \"^1.1.0\"\n\t}\n}\n"
  },
  {
    "path": "dashboard/parts.nix",
    "content": "{ inputs, ... }:\n{\n  perSystem =\n    { pkgs, lib, ... }:\n    let\n      # Filter source to ONLY include package.json and package-lock.json\n      # This ensures prettier-svelte only rebuilds when lockfiles change\n      dashboardLockfileSrc = lib.cleanSourceWith {\n        src = inputs.self;\n        filter =\n          path: type:\n          let\n            baseName = builtins.baseNameOf path;\n            isDashboardDir = baseName == \"dashboard\" && type == \"directory\";\n            isPackageFile =\n              (lib.hasInfix \"/dashboard/\" path || lib.hasSuffix \"/dashboard\" (builtins.dirOf path))\n              && (baseName == \"package.json\" || baseName == \"package-lock.json\");\n          in\n          isDashboardDir || isPackageFile;\n      };\n\n      # Stub source with lockfiles and minimal files for build to succeed\n      # This allows prettier-svelte to avoid rebuilding when dashboard source changes\n      dashboardStubSrc = pkgs.runCommand \"dashboard-stub-src\" { } ''\n        mkdir -p $out\n        cp ${dashboardLockfileSrc}/dashboard/package.json $out/\n        cp ${dashboardLockfileSrc}/dashboard/package-lock.json $out/\n        # Minimal files so vite build succeeds (produces empty output)\n        echo '<!DOCTYPE html><html><head></head><body></body></html>' > $out/index.html\n        mkdir -p $out/src\n        touch $out/src/app.html\n      '';\n\n      # Deps-only build using stub source (for prettier-svelte)\n      # Only rebuilds when package.json or package-lock.json change\n      dashboardDeps = inputs.dream2nix.lib.evalModules {\n        packageSets.nixpkgs = pkgs;\n        modules = [\n          ./dashboard.nix\n          {\n            paths.projectRoot = inputs.self;\n            paths.projectRootFile = \"flake.nix\";\n            paths.package = inputs.self + \"/dashboard\";\n          }\n          {\n            deps.dashboardSrc = lib.mkForce dashboardStubSrc;\n          }\n          # Override build phases to skip the actual build - just need node_modules\n          {\n            mkDerivation = {\n              buildPhase = lib.mkForce \"true\";\n              installPhase = lib.mkForce ''\n                runHook preInstall\n                runHook postInstall\n              '';\n            };\n          }\n        ];\n      };\n\n      # Filter source to only include dashboard directory\n      dashboardSrc = lib.cleanSourceWith {\n        src = inputs.self;\n        filter =\n          path: type:\n          let\n            baseName = builtins.baseNameOf path;\n            inDashboardDir =\n              (lib.hasInfix \"/dashboard/\" path)\n              || (lib.hasSuffix \"/dashboard\" (builtins.dirOf path))\n              || (baseName == \"dashboard\" && type == \"directory\");\n          in\n          inDashboardDir;\n      };\n\n      # Build the dashboard with dream2nix (includes node_modules in output)\n      dashboardFull = inputs.dream2nix.lib.evalModules {\n        packageSets.nixpkgs = pkgs;\n        modules = [\n          ./dashboard.nix\n          {\n            paths.projectRoot = inputs.self;\n            paths.projectRootFile = \"flake.nix\";\n            paths.package = inputs.self + \"/dashboard\";\n          }\n          # Inject the filtered source\n          {\n            deps.dashboardSrc = lib.mkForce \"${dashboardSrc}/dashboard\";\n          }\n        ];\n      };\n    in\n    {\n      # Extract just the static site from the full build\n      packages.dashboard = pkgs.runCommand \"exo-dashboard\" { 
} ''\n        cp -r ${dashboardFull}/build $out\n      '';\n\n      # Prettier with svelte plugin for treefmt\n      # Uses dashboardDeps instead of dashboardFull to avoid rebuilding on source changes\n      packages.prettier-svelte = pkgs.writeShellScriptBin \"prettier-svelte\" ''\n        export NODE_PATH=\"${dashboardDeps}/lib/node_modules/exo-dashboard/node_modules\"\n        exec ${pkgs.nodejs}/bin/node \\\n          ${dashboardDeps}/lib/node_modules/exo-dashboard/node_modules/prettier/bin/prettier.cjs \\\n          --plugin \"${dashboardDeps}/lib/node_modules/exo-dashboard/node_modules/prettier-plugin-svelte/plugin.js\" \\\n          \"$@\"\n      '';\n    };\n}\n"
  },
  {
    "path": "dashboard/src/app.css",
    "content": "@import 'tailwindcss';\n@import 'tw-animate-css';\n\n@custom-variant dark (&:is(.dark *));\n\n:root {\n\t/* EXO Brand Colors - Command Center Theme (neutral dark greys) */\n\t--exo-black: oklch(0.12 0 0);\n\t--exo-dark-gray: oklch(0.16 0 0);\n\t--exo-medium-gray: oklch(0.22 0 0);\n\t--exo-light-gray: oklch(0.6 0 0);\n\t--exo-yellow: oklch(0.85 0.18 85);\n\t--exo-yellow-darker: oklch(0.7 0.16 85);\n\t--exo-yellow-glow: oklch(0.9 0.2 85);\n\t\n\t/* Gotham-inspired accent colors */\n\t--exo-grid: oklch(0.25 0 0);\n\t--exo-scanline: oklch(0.15 0 0);\n\t--exo-glow-yellow: 0 0 20px oklch(0.85 0.18 85 / 0.3);\n\t--exo-glow-yellow-strong: 0 0 40px oklch(0.85 0.18 85 / 0.5);\n\t\n\t/* Theme Variables */\n\t--radius: 0.375rem;\n\t--background: var(--exo-black);\n\t--foreground: oklch(0.9 0 0);\n\t--card: var(--exo-dark-gray);\n\t--card-foreground: oklch(0.9 0 0);\n\t--popover: var(--exo-dark-gray);\n\t--popover-foreground: oklch(0.9 0 0);\n\t--primary: var(--exo-yellow);\n\t--primary-foreground: var(--exo-black);\n\t--secondary: var(--exo-medium-gray);\n\t--secondary-foreground: oklch(0.9 0 0);\n\t--muted: var(--exo-medium-gray);\n\t--muted-foreground: var(--exo-light-gray);\n\t--accent: var(--exo-medium-gray);\n\t--accent-foreground: oklch(0.9 0 0);\n\t--destructive: oklch(0.6 0.25 25);\n\t--border: oklch(0.22 0 0);\n\t--input: oklch(0.22 0 0);\n\t--ring: var(--exo-yellow);\n}\n\n@theme inline {\n\t--radius-sm: calc(var(--radius) - 2px);\n\t--radius-md: var(--radius);\n\t--radius-lg: calc(var(--radius) + 2px);\n\t--radius-xl: calc(var(--radius) + 4px);\n\t--color-background: var(--background);\n\t--color-foreground: var(--foreground);\n\t--color-card: var(--card);\n\t--color-card-foreground: var(--card-foreground);\n\t--color-popover: var(--popover);\n\t--color-popover-foreground: var(--popover-foreground);\n\t--color-primary: var(--primary);\n\t--color-primary-foreground: var(--primary-foreground);\n\t--color-secondary: var(--secondary);\n\t--color-secondary-foreground: var(--secondary-foreground);\n\t--color-muted: var(--muted);\n\t--color-muted-foreground: var(--muted-foreground);\n\t--color-accent: var(--accent);\n\t--color-accent-foreground: var(--accent-foreground);\n\t--color-destructive: var(--destructive);\n\t--color-border: var(--border);\n\t--color-input: var(--input);\n\t--color-ring: var(--ring);\n\t\n\t/* Custom EXO colors */\n\t--color-exo-yellow: var(--exo-yellow);\n\t--color-exo-yellow-darker: var(--exo-yellow-darker);\n\t--color-exo-black: var(--exo-black);\n\t--color-exo-dark-gray: var(--exo-dark-gray);\n\t--color-exo-medium-gray: var(--exo-medium-gray);\n\t--color-exo-light-gray: var(--exo-light-gray);\n}\n\n@layer base {\n\t* {\n\t\t@apply border-border outline-ring/50;\n\t}\n\thtml, body {\n\t\t@apply bg-background text-foreground;\n\t\tfont-family: 'SF Mono', 'Fira Code', 'Monaco', 'Consolas', 'Liberation Mono', monospace;\n\t\tletter-spacing: 0.02em;\n\t}\n}\n\n@layer utilities {\n\t.scrollbar-hide {\n\t\t&::-webkit-scrollbar {\n\t\t\tdisplay: none;\n\t\t}\n\t\t-ms-overflow-style: none;\n\t\tscrollbar-width: none;\n\t}\n\t\n\t/* CRT Scanline effect */\n\t.scanlines {\n\t\tposition: relative;\n\t\t&::before {\n\t\t\tcontent: '';\n\t\t\tposition: absolute;\n\t\t\tinset: 0;\n\t\t\tbackground: repeating-linear-gradient(\n\t\t\t\t0deg,\n\t\t\t\ttransparent,\n\t\t\t\ttransparent 2px,\n\t\t\t\toklch(0 0 0 / 0.03) 2px,\n\t\t\t\toklch(0 0 0 / 0.03) 4px\n\t\t\t);\n\t\t\tpointer-events: none;\n\t\t\tz-index: 100;\n\t\t}\n\t}\n\t\n\t/* Command panel styling 
*/\n\t.command-panel {\n\t\tbackground: linear-gradient(\n\t\t\t180deg,\n\t\t\toklch(0.16 0 0 / 0.95) 0%,\n\t\t\toklch(0.12 0 0 / 0.98) 100%\n\t\t);\n\t\tborder: 1px solid oklch(0.25 0 0);\n\t\tbox-shadow: \n\t\t\tinset 0 1px 0 oklch(1 0 0 / 0.03),\n\t\t\t0 4px 20px oklch(0 0 0 / 0.5);\n\t}\n\t\n\t/* Glow text */\n\t.glow-text {\n\t\ttext-shadow: \n\t\t\t0 0 10px oklch(0.85 0.18 85 / 0.5),\n\t\t\t0 0 20px oklch(0.85 0.18 85 / 0.3),\n\t\t\t0 0 40px oklch(0.85 0.18 85 / 0.1);\n\t}\n\t\n\t/* Status indicator pulse */\n\t.status-pulse {\n\t\tanimation: statusPulse 2s ease-in-out infinite;\n\t}\n\t\n\t/* Grid background */\n\t.grid-bg {\n\t\tbackground-image: \n\t\t\tlinear-gradient(oklch(0.2 0 0 / 0.3) 1px, transparent 1px),\n\t\t\tlinear-gradient(90deg, oklch(0.2 0 0 / 0.3) 1px, transparent 1px);\n\t\tbackground-size: 40px 40px;\n\t}\n}\n\n/* Animations */\n@keyframes flowAnimation {\n\tfrom {\n\t\tstroke-dashoffset: 0;\n\t}\n\tto {\n\t\tstroke-dashoffset: -16;\n\t}\n}\n\n@keyframes statusPulse {\n\t0%, 100% {\n\t\topacity: 1;\n\t}\n\t50% {\n\t\topacity: 0.5;\n\t}\n}\n\n@keyframes radarSweep {\n\tfrom {\n\t\ttransform: rotate(0deg);\n\t}\n\tto {\n\t\ttransform: rotate(360deg);\n\t}\n}\n\n@keyframes glowPulse {\n\t0%, 100% {\n\t\tbox-shadow: 0 0 5px oklch(0.85 0.18 85 / 0.3), 0 0 10px oklch(0.85 0.18 85 / 0.1);\n\t}\n\t50% {\n\t\tbox-shadow: 0 0 15px oklch(0.85 0.18 85 / 0.5), 0 0 30px oklch(0.85 0.18 85 / 0.2);\n\t}\n}\n\n@keyframes dataPulse {\n\t0%, 100% {\n\t\topacity: 0.6;\n\t}\n\t50% {\n\t\topacity: 1;\n\t}\n}\n\n.graph-link {\n\tstroke: oklch(0.85 0.18 85 / 0.4);\n\tstroke-width: 1.5px;\n\tstroke-dasharray: 8, 8;\n\tanimation: flowAnimation 1s linear infinite;\n\tfilter: drop-shadow(0 0 3px oklch(0.85 0.18 85 / 0.5));\n}\n\n/* Onboarding step 2: connection line between devices */\n.onboarding-connection-line {\n\tstroke: oklch(0.85 0.18 85 / 0.5);\n\tstroke-width: 1.5px;\n\tstroke-dasharray: 6, 6;\n\tanimation: flowAnimation 1s linear infinite;\n\tfilter: drop-shadow(0 0 4px oklch(0.85 0.18 85 / 0.4));\n}\n\n/* Onboarding step 4: red connection line for disconnect */\n.onboarding-connection-line-red {\n\tstroke: rgba(220, 38, 38, 0.7);\n\tstroke-width: 1.5px;\n\tstroke-dasharray: 6, 6;\n\tfilter: drop-shadow(0 0 2px rgba(220, 38, 38, 0.3));\n}\n\n.graph-link-active {\n\tstroke: oklch(0.85 0.18 85 / 0.8);\n\tstroke-width: 2px;\n\tfilter: drop-shadow(0 0 6px oklch(0.85 0.18 85 / 0.8));\n}\n\n/* CRT Screen effect for topology */\n.crt-screen {\n\tposition: relative;\n\tborder-radius: 50%;\n\tbackground: radial-gradient(\n\t\tellipse at center,\n\t\toklch(0.16 0 0) 0%,\n\t\toklch(0.12 0 0) 50%,\n\t\toklch(0.09 0 0) 100%\n\t);\n\tbox-shadow:\n\t\tinset 0 0 100px oklch(0 0 0 / 0.5),\n\t\t0 0 50px oklch(0.85 0.18 85 / 0.1);\n}\n\n/* Data readout styling */\n.data-readout {\n\tfont-family: 'SF Mono', 'Fira Code', monospace;\n\tfont-size: 11px;\n\tletter-spacing: 0.05em;\n\ttext-transform: uppercase;\n}\n\n/* Terminal cursor blink */\n.cursor-blink {\n\tanimation: cursorBlink 1s step-end infinite;\n}\n\n@keyframes cursorBlink {\n\t0%, 100% { opacity: 1; }\n\t50% { opacity: 0; }\n}\n\n/* Custom scrollbar for command center */\n::-webkit-scrollbar {\n\twidth: 6px;\n\theight: 6px;\n}\n\n::-webkit-scrollbar-track {\n\tbackground: oklch(0.1 0 0);\n}\n\n::-webkit-scrollbar-thumb {\n\tbackground: oklch(0.3 0 0);\n\tborder-radius: 3px;\n}\n\n::-webkit-scrollbar-thumb:hover {\n\tbackground: oklch(0.85 0.18 85 / 0.5);\n}\n\n/* Remove focus outline/border for inputs */\ninput:focus, textarea:focus 
{\n\toutline: none;\n\tbox-shadow: none;\n}\n\n/* Shooting Stars Animation */\n.shooting-stars {\n\tposition: fixed;\n\tinset: 0;\n\toverflow: hidden;\n\tpointer-events: none;\n\tz-index: 0;\n}\n\n.shooting-star {\n\tposition: absolute;\n\twidth: 3px;\n\theight: 3px;\n\tbackground: oklch(0.85 0.18 85 / 1);\n\tborder-radius: 50%;\n\tbox-shadow: 0 0 6px oklch(0.85 0.18 85 / 0.8);\n\tanimation: shootingStar var(--duration, 3s) linear infinite;\n\tanimation-delay: var(--delay, 0s);\n\topacity: 0;\n}\n\n.shooting-star::before {\n\tcontent: '';\n\tposition: absolute;\n\twidth: 80px;\n\theight: 2px;\n\tbackground: linear-gradient(90deg, oklch(0.85 0.18 85 / 0), oklch(0.85 0.18 85 / 0.6));\n\ttransform: rotate(45deg);\n\ttransform-origin: right center;\n\ttop: 0;\n\tright: 2px;\n}\n\n@keyframes shootingStar {\n\t0% {\n\t\topacity: 0;\n\t\ttransform: translate(0, 0);\n\t}\n\t0.5% {\n\t\topacity: 1;\n\t}\n\t2.5% {\n\t\topacity: 0.8;\n\t\ttransform: translate(300px, 300px);\n\t}\n\t3.5% {\n\t\topacity: 0;\n\t\ttransform: translate(400px, 400px);\n\t}\n\t100% {\n\t\topacity: 0;\n\t\ttransform: translate(400px, 400px);\n\t}\n}\n\n/* Onboarding smooth fade-in */\n@keyframes onb-fade-in {\n\tfrom { opacity: 0; transform: translateY(8px); }\n\tto { opacity: 1; transform: none; }\n}\n\n/* Opacity-only fade-in — no transform, safe for containers with fixed-position children */\n@keyframes onb-fade-opacity {\n\tfrom { opacity: 0; }\n\tto { opacity: 1; }\n}\n\n/* Respect reduced motion preference */\n@media (prefers-reduced-motion: reduce) {\n\t.shooting-star,\n\t.shooting-star::before {\n\t\tanimation: none !important;\n\t\topacity: 0 !important;\n\t}\n\t.graph-link {\n\t\tanimation: none;\n\t}\n\t.status-pulse {\n\t\tanimation: none;\n\t}\n\t.cursor-blink {\n\t\tanimation: none;\n\t}\n\t.onboarding-connection-line {\n\t\tanimation: none;\n\t}\n\t.onboarding-connection-line-red {\n\t\tanimation: none;\n\t}\n\t[style*=\"onb-fade-in\"],\n\t[style*=\"onb-fade-opacity\"] {\n\t\tanimation: none !important;\n\t\topacity: 1 !important;\n\t}\n\t*,\n\t*::before,\n\t*::after {\n\t\ttransition-duration: 0.01ms !important;\n\t\tanimation-duration: 0.01ms !important;\n\t\tanimation-iteration-count: 1 !important;\n\t}\n}\n"
  },
  {
    "path": "dashboard/src/app.d.ts",
    "content": "// See https://svelte.dev/docs/kit/types#app.d.ts\n// for information about these interfaces\ndeclare global {\n  namespace App {\n    // interface Error {}\n    // interface Locals {}\n    // interface PageData {}\n    // interface PageState {}\n    // interface Platform {}\n  }\n}\n\nexport {};\n"
  },
  {
    "path": "dashboard/src/app.html",
    "content": "<!doctype html>\n<html lang=\"en\">\n\t<head>\n\t\t<meta charset=\"utf-8\" />\n\t\t<link rel=\"icon\" href=\"%sveltekit.assets%/favicon.ico\" />\n\t\t<meta name=\"viewport\" content=\"width=device-width, initial-scale=1\" />\n\t\t<title>EXO</title>\n\t\t%sveltekit.head%\n\t</head>\n\t<body data-sveltekit-preload-data=\"hover\">\n\t\t<div style=\"display: contents\">%sveltekit.body%</div>\n\t</body>\n</html>\n\n"
  },
  {
    "path": "dashboard/src/lib/components/ChatAttachments.svelte",
    "content": "<script lang=\"ts\">\n  import type { ChatUploadedFile } from \"$lib/types/files\";\n  import { formatFileSize, getFileCategory } from \"$lib/types/files\";\n\n  interface Props {\n    files: ChatUploadedFile[];\n    readonly?: boolean;\n    onRemove?: (fileId: string) => void;\n  }\n\n  let { files, readonly = false, onRemove }: Props = $props();\n\n  function getFileIcon(file: ChatUploadedFile): string {\n    const category = getFileCategory(file.type, file.name);\n    switch (category) {\n      case \"image\":\n        return \"🖼\";\n      case \"text\":\n        return \"📄\";\n      case \"pdf\":\n        return \"📑\";\n      case \"audio\":\n        return \"🎵\";\n      default:\n        return \"📎\";\n    }\n  }\n\n  function truncateName(name: string, maxLen: number = 20): string {\n    if (name.length <= maxLen) return name;\n    const ext = name.slice(name.lastIndexOf(\".\"));\n    const base = name.slice(0, name.lastIndexOf(\".\"));\n    const available = maxLen - ext.length - 3;\n    return base.slice(0, available) + \"...\" + ext;\n  }\n</script>\n\n{#if files.length > 0}\n  <div class=\"flex flex-wrap gap-2 mb-3 px-1\">\n    {#each files as file (file.id)}\n      <div\n        class=\"group relative flex items-center gap-2 bg-exo-dark-gray/80 border border-exo-yellow/30 rounded px-2.5 py-1.5 text-xs font-mono transition-all hover:border-exo-yellow/50 hover:shadow-[0_0_10px_rgba(255,215,0,0.1)]\"\n      >\n        <!-- File preview or icon -->\n        {#if file.preview && getFileCategory(file.type, file.name) === \"image\"}\n          <img\n            src={file.preview}\n            alt={file.name}\n            class=\"w-8 h-8 object-cover rounded border border-exo-yellow/20\"\n          />\n        {:else}\n          <span class=\"text-base\">{getFileIcon(file)}</span>\n        {/if}\n\n        <!-- File info -->\n        <div class=\"flex flex-col min-w-0\">\n          <span\n            class=\"text-exo-yellow truncate max-w-[120px]\"\n            title={file.name}\n          >\n            {truncateName(file.name)}\n          </span>\n          <span class=\"text-exo-light-gray text-xs\">\n            {formatFileSize(file.size)}\n          </span>\n        </div>\n\n        <!-- Remove button -->\n        {#if !readonly && onRemove}\n          <button\n            type=\"button\"\n            onclick={() => onRemove?.(file.id)}\n            class=\"ml-1 w-4 h-4 flex items-center justify-center text-exo-light-gray hover:text-red-400 transition-colors cursor-pointer\"\n            title=\"Remove file\"\n          >\n            <svg\n              class=\"w-3 h-3\"\n              fill=\"none\"\n              viewBox=\"0 0 24 24\"\n              stroke=\"currentColor\"\n            >\n              <path\n                stroke-linecap=\"round\"\n                stroke-linejoin=\"round\"\n                stroke-width=\"2\"\n                d=\"M6 18L18 6M6 6l12 12\"\n              />\n            </svg>\n          </button>\n        {/if}\n      </div>\n    {/each}\n  </div>\n{/if}\n"
  },
  {
    "path": "dashboard/src/lib/components/ChatForm.svelte",
    "content": "<script lang=\"ts\">\n  import {\n    isLoading,\n    editingImage,\n    clearEditingImage,\n    selectedChatModel,\n    ttftMs,\n    tps,\n    totalTokens,\n    thinkingEnabled as thinkingEnabledStore,\n    setConversationThinking,\n    stopGeneration,\n  } from \"$lib/stores/app.svelte\";\n  import ChatAttachments from \"./ChatAttachments.svelte\";\n  import ImageParamsPanel from \"./ImageParamsPanel.svelte\";\n  import type { ChatUploadedFile } from \"$lib/types/files\";\n  import { processUploadedFiles, getAcceptString } from \"$lib/types/files\";\n\n  interface Props {\n    class?: string;\n    placeholder?: string;\n    showHelperText?: boolean;\n    autofocus?: boolean;\n    showModelSelector?: boolean;\n    modelTasks?: Record<string, string[]>;\n    modelCapabilities?: Record<string, string[]>;\n    onSend?: () => void;\n    onAutoSend: (\n      content: string,\n      files?: {\n        id: string;\n        name: string;\n        type: string;\n        textContent?: string;\n        preview?: string;\n      }[],\n    ) => void;\n    onOpenModelPicker?: () => void;\n    modelDisplayOverride?: string;\n  }\n\n  let {\n    class: className = \"\",\n    placeholder = \"Ask anything\",\n    showHelperText = false,\n    autofocus = true,\n    showModelSelector = false,\n    modelTasks = {},\n    modelCapabilities = {},\n    onSend,\n    onAutoSend,\n    onOpenModelPicker,\n    modelDisplayOverride,\n  }: Props = $props();\n\n  let message = $state(\"\");\n  let textareaRef: HTMLTextAreaElement | undefined = $state();\n  let fileInputRef: HTMLInputElement | undefined = $state();\n  let uploadedFiles = $state<ChatUploadedFile[]>([]);\n  let isDragOver = $state(false);\n  const thinkingEnabled = $derived(thinkingEnabledStore());\n  let loading = $derived(isLoading());\n  const currentModel = $derived(selectedChatModel());\n  const currentTtft = $derived(ttftMs());\n  const currentTps = $derived(tps());\n  const currentTokens = $derived(totalTokens());\n  const currentEditingImage = $derived(editingImage());\n  const isEditMode = $derived(currentEditingImage !== null);\n\n  // Accept all supported file types\n  const acceptString = getAcceptString([\"image\", \"text\", \"pdf\"]);\n\n  function modelSupportsImageGeneration(modelId: string): boolean {\n    const tasks = modelTasks[modelId] || [];\n    return tasks.includes(\"TextToImage\") || tasks.includes(\"ImageToImage\");\n  }\n\n  function modelSupportsTextToImage(modelId: string): boolean {\n    const tasks = modelTasks[modelId] || [];\n    return tasks.includes(\"TextToImage\");\n  }\n\n  function modelSupportsOnlyImageEditing(modelId: string): boolean {\n    const tasks = modelTasks[modelId] || [];\n    return tasks.includes(\"ImageToImage\") && !tasks.includes(\"TextToImage\");\n  }\n\n  function modelSupportsImageEditing(modelId: string): boolean {\n    const tasks = modelTasks[modelId] || [];\n    return tasks.includes(\"ImageToImage\");\n  }\n\n  const isImageModel = $derived(() => {\n    if (!currentModel) return false;\n    return (\n      modelSupportsTextToImage(currentModel) ||\n      modelSupportsImageEditing(currentModel)\n    );\n  });\n\n  const modelSupportsThinking = $derived(() => {\n    if (!currentModel) return false;\n    const caps = modelCapabilities[currentModel] || [];\n    return caps.includes(\"thinking_toggle\") && caps.includes(\"text\");\n  });\n\n  const isEditOnlyWithoutImage = $derived(\n    currentModel !== null &&\n      modelSupportsOnlyImageEditing(currentModel) &&\n      !isEditMode 
&&\n      uploadedFiles.length === 0,\n  );\n\n  // Show edit mode when: explicit edit mode OR (model supports ImageToImage AND files attached)\n  const shouldShowEditMode = $derived(\n    isEditMode ||\n      (currentModel &&\n        modelSupportsImageEditing(currentModel) &&\n        uploadedFiles.length > 0),\n  );\n\n  // Short label for the currently selected model\n  const currentModelLabel = $derived(\n    currentModel\n      ? currentModel.split(\"/\").pop() || currentModel\n      : modelDisplayOverride\n        ? modelDisplayOverride.split(\"/\").pop() || modelDisplayOverride\n        : \"\",\n  );\n\n  async function handleFiles(files: File[]) {\n    if (files.length === 0) return;\n    const processed = await processUploadedFiles(files);\n    uploadedFiles = [...uploadedFiles, ...processed];\n  }\n\n  function handleFileInputChange(event: Event) {\n    const input = event.target as HTMLInputElement;\n    if (input.files && input.files.length > 0) {\n      handleFiles(Array.from(input.files));\n      input.value = \"\"; // Reset for next selection\n    }\n  }\n\n  function handleFileRemove(fileId: string) {\n    uploadedFiles = uploadedFiles.filter((f) => f.id !== fileId);\n  }\n\n  function handlePaste(event: ClipboardEvent) {\n    if (!event.clipboardData) return;\n\n    const files = Array.from(event.clipboardData.items)\n      .filter((item) => item.kind === \"file\")\n      .map((item) => item.getAsFile())\n      .filter((file): file is File => file !== null);\n\n    if (files.length > 0) {\n      event.preventDefault();\n      handleFiles(files);\n      return;\n    }\n\n    // Handle long text paste as file\n    const text = event.clipboardData.getData(\"text/plain\");\n    if (text.length > 2500) {\n      event.preventDefault();\n      const textFile = new File([text], \"pasted-text.txt\", {\n        type: \"text/plain\",\n      });\n      handleFiles([textFile]);\n    }\n  }\n\n  function handleDragOver(event: DragEvent) {\n    event.preventDefault();\n    isDragOver = true;\n  }\n\n  function handleDragLeave(event: DragEvent) {\n    event.preventDefault();\n    isDragOver = false;\n  }\n\n  function handleDrop(event: DragEvent) {\n    event.preventDefault();\n    isDragOver = false;\n\n    if (event.dataTransfer?.files) {\n      handleFiles(Array.from(event.dataTransfer.files));\n    }\n  }\n\n  function handleKeydown(event: KeyboardEvent) {\n    // Prevent form submission during IME composition (e.g., Chinese, Japanese, Korean input)\n    if (event.isComposing || event.keyCode === 229) {\n      return;\n    }\n\n    if (event.key === \"Enter\" && !event.shiftKey) {\n      event.preventDefault();\n      handleSubmit();\n    }\n  }\n\n  function handleSubmit() {\n    if ((!message.trim() && uploadedFiles.length === 0) || loading) return;\n    if (isEditOnlyWithoutImage) return;\n\n    const content = message.trim();\n    const files = [...uploadedFiles];\n\n    message = \"\";\n    uploadedFiles = [];\n    resetTextareaHeight();\n\n    // Parent controls all send logic (including image routing,\n    // launching non-running models before sending, etc.)\n    onAutoSend(content, files);\n    onSend?.();\n    setTimeout(() => textareaRef?.focus(), 10);\n  }\n\n  function handleInput() {\n    if (!textareaRef) return;\n    textareaRef.style.height = \"auto\";\n    textareaRef.style.height = Math.min(textareaRef.scrollHeight, 150) + \"px\";\n  }\n\n  function resetTextareaHeight() {\n    if (textareaRef) {\n      textareaRef.style.height = \"auto\";\n    }\n  }\n\n  function 
openFilePicker() {\n    fileInputRef?.click();\n  }\n\n  // Track previous loading state to detect when loading completes\n  let wasLoading = $state(false);\n\n  $effect(() => {\n    if (autofocus && textareaRef) {\n      setTimeout(() => textareaRef?.focus(), 10);\n    }\n  });\n\n  // Refocus after loading completes (AI response finished)\n  $effect(() => {\n    if (wasLoading && !loading && textareaRef) {\n      setTimeout(() => textareaRef?.focus(), 50);\n    }\n    wasLoading = loading;\n  });\n\n  const canSend = $derived(\n    message.trim().length > 0 || uploadedFiles.length > 0,\n  );\n</script>\n\n<!-- Hidden file input -->\n<input\n  bind:this={fileInputRef}\n  type=\"file\"\n  accept={acceptString}\n  multiple\n  class=\"hidden\"\n  onchange={handleFileInputChange}\n/>\n\n<form\n  onsubmit={(e) => {\n    e.preventDefault();\n    handleSubmit();\n  }}\n  class=\"w-full {className}\"\n  ondragover={handleDragOver}\n  ondragleave={handleDragLeave}\n  ondrop={handleDrop}\n>\n  <div\n    class=\"relative command-panel rounded overflow-hidden transition-all duration-200 {isDragOver\n      ? 'ring-2 ring-exo-yellow ring-opacity-50'\n      : ''}\"\n  >\n    <!-- Top accent line -->\n    <div\n      class=\"absolute top-0 left-0 right-0 h-px bg-gradient-to-r from-transparent via-exo-yellow/50 to-transparent\"\n    ></div>\n\n    <!-- Drag overlay -->\n    {#if isDragOver}\n      <div\n        class=\"absolute inset-0 bg-exo-dark-gray/80 z-10 flex items-center justify-center\"\n      >\n        <div class=\"text-exo-yellow text-sm font-mono tracking-wider uppercase\">\n          DROP FILES HERE\n        </div>\n      </div>\n    {/if}\n\n    <!-- Edit mode banner -->\n    {#if isEditMode && currentEditingImage}\n      <div\n        class=\"flex items-center gap-3 px-3 py-2 bg-exo-yellow/10 border-b border-exo-yellow/30\"\n      >\n        <img\n          src={currentEditingImage.imageDataUrl}\n          alt=\"Source for editing\"\n          class=\"w-10 h-10 object-cover rounded border border-exo-yellow/30\"\n        />\n        <div class=\"flex-1\">\n          <span\n            class=\"text-xs font-mono tracking-wider uppercase text-exo-yellow\"\n            >EDITING IMAGE</span\n          >\n        </div>\n        <button\n          type=\"button\"\n          onclick={() => clearEditingImage()}\n          class=\"px-2 py-1 text-xs font-mono tracking-wider uppercase bg-exo-medium-gray/30 text-exo-light-gray border border-exo-medium-gray/50 rounded hover:bg-exo-medium-gray/50 hover:text-exo-yellow transition-colors cursor-pointer\"\n        >\n          CANCEL\n        </button>\n      </div>\n    {/if}\n\n    <!-- Model selector (when enabled) -->\n    {#if showModelSelector}\n      <div\n        class=\"flex items-center justify-between gap-2 px-3 py-2 border-b border-exo-medium-gray/30\"\n      >\n        <div class=\"flex items-center gap-2 flex-1\">\n          <span\n            class=\"text-xs text-exo-light-gray uppercase tracking-wider flex-shrink-0\"\n            >MODEL:</span\n          >\n          <!-- Model button — opens the full model picker -->\n          <div class=\"relative flex-1 max-w-xs\">\n            <button\n              type=\"button\"\n              onclick={() => onOpenModelPicker?.()}\n              class=\"w-full bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-3 pr-8 py-1.5 text-xs font-mono text-left tracking-wide cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70\"\n          
  >\n              {#if currentModelLabel}\n                <span class=\"text-exo-yellow truncate\">{currentModelLabel}</span\n                >\n              {:else}\n                <span class=\"text-exo-light-gray/50\">— SELECT MODEL —</span>\n              {/if}\n            </button>\n            <div\n              class=\"absolute right-2 top-1/2 -translate-y-1/2 pointer-events-none\"\n            >\n              <svg\n                class=\"w-3 h-3 text-exo-yellow/60\"\n                fill=\"none\"\n                viewBox=\"0 0 24 24\"\n                stroke=\"currentColor\"\n              >\n                <path\n                  stroke-linecap=\"round\"\n                  stroke-linejoin=\"round\"\n                  stroke-width=\"2\"\n                  d=\"M19 9l-7 7-7-7\"\n                />\n              </svg>\n            </div>\n          </div>\n        </div>\n        <!-- Thinking toggle -->\n        {#if modelSupportsThinking()}\n          <button\n            type=\"button\"\n            onclick={() => setConversationThinking(!thinkingEnabled)}\n            class=\"flex items-center gap-1.5 px-2 py-1 rounded text-xs font-mono tracking-wide transition-all duration-200 flex-shrink-0 cursor-pointer border {thinkingEnabled\n              ? 'bg-exo-yellow/15 border-exo-yellow/40 text-exo-yellow'\n              : 'bg-exo-medium-gray/30 border-exo-medium-gray/50 text-exo-light-gray/60 hover:text-exo-light-gray'}\"\n            title={thinkingEnabled\n              ? \"Thinking enabled — click to disable\"\n              : \"Thinking disabled — click to enable\"}\n          >\n            <svg\n              class=\"w-3.5 h-3.5\"\n              viewBox=\"0 0 24 24\"\n              fill=\"none\"\n              stroke=\"currentColor\"\n              stroke-width=\"1.5\"\n            >\n              <path\n                d=\"M12 2a7 7 0 0 0-7 7c0 2.38 1.19 4.47 3 5.74V17a1 1 0 0 0 1 1h6a1 1 0 0 0 1-1v-2.26c1.81-1.27 3-3.36 3-5.74a7 7 0 0 0-7-7zM9 20h6M10 22h4\"\n                stroke-linecap=\"round\"\n                stroke-linejoin=\"round\"\n              />\n            </svg>\n            <span>{thinkingEnabled ? 
\"THINK\" : \"NO THINK\"}</span>\n          </button>\n        {/if}\n\n        <!-- Performance stats -->\n        {#if currentTtft !== null || currentTps !== null}\n          <div class=\"flex items-center gap-4 text-xs font-mono flex-shrink-0\">\n            {#if currentTtft !== null}\n              <span class=\"text-exo-light-gray\">\n                <span class=\"text-white/70\">TTFT</span>\n                <span class=\"text-exo-yellow\">{currentTtft.toFixed(1)}ms</span>\n              </span>\n            {/if}\n            {#if currentTps !== null}\n              <span class=\"text-exo-light-gray\">\n                <span class=\"text-white/70\">TPS</span>\n                <span class=\"text-exo-yellow\">{currentTps.toFixed(1)}</span>\n                <span class=\"text-white/60\">tok/s</span>\n                <span class=\"text-white/50\"\n                  >({(1000 / currentTps).toFixed(1)} ms/tok)</span\n                >\n              </span>\n            {/if}\n          </div>\n        {/if}\n      </div>\n    {/if}\n\n    <!-- Image params panel (shown for image models or edit mode) -->\n    {#if showModelSelector && (isImageModel() || isEditMode)}\n      <ImageParamsPanel {isEditMode} />\n    {/if}\n\n    <!-- Attached files preview -->\n    {#if uploadedFiles.length > 0}\n      <div class=\"px-3 pt-3\">\n        <ChatAttachments files={uploadedFiles} onRemove={handleFileRemove} />\n      </div>\n    {/if}\n\n    <!-- Input area -->\n    <div class=\"flex items-start gap-2 sm:gap-3 py-3 px-3 sm:px-4\">\n      <!-- Attach file button -->\n      <button\n        type=\"button\"\n        onclick={openFilePicker}\n        disabled={loading}\n        class=\"flex items-center justify-center w-7 h-7 rounded text-exo-light-gray hover:text-exo-yellow transition-all disabled:opacity-50 disabled:cursor-not-allowed flex-shrink-0 cursor-pointer\"\n        title=\"Attach file\"\n      >\n        <svg\n          class=\"w-4 h-4\"\n          fill=\"none\"\n          viewBox=\"0 0 24 24\"\n          stroke=\"currentColor\"\n        >\n          <path\n            stroke-linecap=\"round\"\n            stroke-linejoin=\"round\"\n            stroke-width=\"2\"\n            d=\"M15.172 7l-6.586 6.586a2 2 0 102.828 2.828l6.414-6.586a4 4 0 00-5.656-5.656l-6.415 6.585a6 6 0 108.486 8.486L20.5 13\"\n          />\n        </svg>\n      </button>\n\n      <!-- Terminal prompt -->\n      <span class=\"text-exo-yellow text-sm font-bold flex-shrink-0 leading-7\"\n        >▶</span\n      >\n\n      <textarea\n        bind:this={textareaRef}\n        bind:value={message}\n        onkeydown={handleKeydown}\n        oninput={handleInput}\n        onpaste={handlePaste}\n        placeholder={isEditOnlyWithoutImage\n          ? \"Attach an image to edit...\"\n          : isEditMode\n            ? \"Describe how to edit this image...\"\n            : isImageModel()\n              ? 
\"Describe the image you want to generate...\"\n              : placeholder}\n        rows={1}\n        class=\"flex-1 resize-none bg-transparent text-foreground placeholder:text-exo-light-gray/60 placeholder:text-sm placeholder:tracking-[0.15em] placeholder:leading-7 focus:outline-none focus:ring-0 focus:border-none text-sm leading-7 font-mono\"\n        style=\"min-height: 28px; max-height: 150px;\"\n      ></textarea>\n\n      {#if loading}\n        <button\n          type=\"button\"\n          onclick={() => stopGeneration()}\n          class=\"px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] font-medium transition-all duration-200 whitespace-nowrap bg-exo-medium-gray/70 text-exo-light-gray hover:bg-exo-medium-gray hover:text-white\"\n          aria-label=\"Stop generation\"\n        >\n          <span class=\"inline-flex items-center gap-1 sm:gap-2\">\n            <svg\n              class=\"w-3 h-3 sm:w-3.5 sm:h-3.5\"\n              fill=\"currentColor\"\n              viewBox=\"0 0 24 24\"\n            >\n              <rect x=\"6\" y=\"6\" width=\"12\" height=\"12\" rx=\"1\" />\n            </svg>\n            <span class=\"hidden sm:inline\">Cancel</span>\n          </span>\n        </button>\n      {:else}\n        <button\n          type=\"submit\"\n          disabled={!canSend || isEditOnlyWithoutImage}\n          class=\"px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap\n\t\t\t\t\t{!canSend || isEditOnlyWithoutImage\n            ? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'\n            : 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}\"\n          aria-label={shouldShowEditMode\n            ? \"Edit image\"\n            : isImageModel()\n              ? 
\"Generate image\"\n              : \"Send message\"}\n        >\n          {#if shouldShowEditMode}\n            <span class=\"inline-flex items-center gap-1.5\">\n              <svg\n                class=\"w-3.5 h-3.5\"\n                fill=\"none\"\n                viewBox=\"0 0 24 24\"\n                stroke=\"currentColor\"\n                stroke-width=\"2\"\n              >\n                <path\n                  stroke-linecap=\"round\"\n                  stroke-linejoin=\"round\"\n                  d=\"M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z\"\n                />\n              </svg>\n              <span>EDIT</span>\n            </span>\n          {:else if isEditOnlyWithoutImage}\n            <span class=\"inline-flex items-center gap-1.5\">\n              <svg\n                class=\"w-3.5 h-3.5\"\n                fill=\"none\"\n                viewBox=\"0 0 24 24\"\n                stroke=\"currentColor\"\n                stroke-width=\"2\"\n              >\n                <path\n                  stroke-linecap=\"round\"\n                  stroke-linejoin=\"round\"\n                  d=\"M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z\"\n                />\n              </svg>\n              <span>EDIT</span>\n            </span>\n          {:else if isImageModel()}\n            <span class=\"inline-flex items-center gap-1.5\">\n              <svg\n                class=\"w-3.5 h-3.5\"\n                fill=\"none\"\n                viewBox=\"0 0 24 24\"\n                stroke=\"currentColor\"\n                stroke-width=\"2\"\n              >\n                <rect x=\"3\" y=\"3\" width=\"18\" height=\"18\" rx=\"2\" ry=\"2\" />\n                <circle cx=\"8.5\" cy=\"8.5\" r=\"1.5\" />\n                <polyline points=\"21 15 16 10 5 21\" />\n              </svg>\n              <span>GENERATE</span>\n            </span>\n          {:else}\n            SEND\n          {/if}\n        </button>\n      {/if}\n    </div>\n\n    <!-- Bottom accent line -->\n    <div\n      class=\"absolute bottom-0 left-0 right-0 h-px bg-gradient-to-r from-transparent via-exo-yellow/30 to-transparent\"\n    ></div>\n  </div>\n\n  {#if showHelperText}\n    <p\n      class=\"mt-2 sm:mt-3 text-center text-xs sm:text-xs text-exo-light-gray tracking-[0.1em] sm:tracking-[0.15em] uppercase\"\n    >\n      <kbd\n        class=\"px-1 sm:px-1.5 py-0.5 rounded bg-exo-medium-gray/30 text-exo-light-gray border border-exo-medium-gray/50\"\n        >ENTER</kbd\n      >\n      <span class=\"mx-0.5 sm:mx-1\">TO SEND</span>\n      <span class=\"text-exo-medium-gray mx-1 sm:mx-2\">|</span>\n      <kbd\n        class=\"px-1 sm:px-1.5 py-0.5 rounded bg-exo-medium-gray/30 text-exo-light-gray border border-exo-medium-gray/50\"\n        >SHIFT+ENTER</kbd\n      >\n      <span class=\"mx-0.5 sm:mx-1\">NEW LINE</span>\n      <span class=\"text-exo-medium-gray mx-1 sm:mx-2\">|</span>\n      <span class=\"text-exo-light-gray\">DRAG & DROP OR PASTE FILES</span>\n    </p>\n  {/if}\n</form>\n"
  },
  {
    "path": "dashboard/src/lib/components/ChatMessages.svelte",
    "content": "<script lang=\"ts\">\n  import {\n    messages,\n    currentResponse,\n    isLoading,\n    prefillProgress,\n    deleteMessage,\n    editAndRegenerate,\n    regenerateLastResponse,\n    regenerateFromToken,\n    setEditingImage,\n  } from \"$lib/stores/app.svelte\";\n  import type { MessageAttachment } from \"$lib/stores/app.svelte\";\n  import MarkdownContent from \"./MarkdownContent.svelte\";\n  import TokenHeatmap from \"./TokenHeatmap.svelte\";\n  import PrefillProgressBar from \"./PrefillProgressBar.svelte\";\n  import ImageLightbox from \"./ImageLightbox.svelte\";\n\n  interface Props {\n    class?: string;\n    scrollParent?: HTMLElement | null;\n  }\n\n  let { class: className = \"\", scrollParent = null }: Props = $props();\n\n  const messageList = $derived(messages());\n  const response = $derived(currentResponse());\n  const loading = $derived(isLoading());\n  const prefill = $derived(prefillProgress());\n\n  // Scroll management - user controls scroll, show button when not at bottom\n  const SCROLL_THRESHOLD = 100;\n  let showScrollButton = $state(false);\n  let lastMessageCount = 0;\n  let containerRef: HTMLDivElement | undefined = $state();\n\n  function getScrollContainer(): HTMLElement | null {\n    if (scrollParent) return scrollParent;\n    return containerRef?.parentElement ?? null;\n  }\n\n  function isNearBottom(el: HTMLElement): boolean {\n    return el.scrollHeight - el.scrollTop - el.clientHeight < SCROLL_THRESHOLD;\n  }\n\n  function scrollToBottom() {\n    const el = getScrollContainer();\n    if (el) {\n      el.scrollTo({ top: el.scrollHeight, behavior: \"smooth\" });\n    }\n  }\n\n  function updateScrollButtonVisibility() {\n    const el = getScrollContainer();\n    if (!el) return;\n    showScrollButton = !isNearBottom(el);\n  }\n\n  // Attach scroll listener\n  $effect(() => {\n    const el = scrollParent ?? 
containerRef?.parentElement;\n    if (!el) return;\n\n    el.addEventListener(\"scroll\", updateScrollButtonVisibility, {\n      passive: true,\n    });\n    // Initial check\n    updateScrollButtonVisibility();\n    return () => el.removeEventListener(\"scroll\", updateScrollButtonVisibility);\n  });\n\n  // Auto-scroll when user sends a new message\n  $effect(() => {\n    const count = messageList.length;\n    if (count > lastMessageCount) {\n      const el = getScrollContainer();\n      if (el) {\n        requestAnimationFrame(() => {\n          el.scrollTo({ top: el.scrollHeight, behavior: \"smooth\" });\n        });\n      }\n    }\n    lastMessageCount = count;\n  });\n\n  // Update scroll button visibility when content changes\n  $effect(() => {\n    // Track response to trigger re-check during streaming\n    const _ = response;\n\n    // Small delay to let DOM update\n    requestAnimationFrame(() => updateScrollButtonVisibility());\n  });\n\n  // Edit state\n  let editingMessageId = $state<string | null>(null);\n  let editContent = $state(\"\");\n  let editTextareaRef: HTMLTextAreaElement | undefined = $state();\n\n  // Delete confirmation state\n  let deleteConfirmId = $state<string | null>(null);\n\n  // Copied state for feedback\n  let copiedMessageId = $state<string | null>(null);\n  let expandedThinkingMessageIds = $state<Set<string>>(new Set());\n\n  // Lightbox state\n  let expandedImageSrc = $state<string | null>(null);\n\n  // Uncertainty heatmap toggle\n  let heatmapMessageIds = $state<Set<string>>(new Set());\n\n  function toggleHeatmap(messageId: string) {\n    const next = new Set(heatmapMessageIds);\n    if (next.has(messageId)) {\n      next.delete(messageId);\n    } else {\n      next.add(messageId);\n    }\n    heatmapMessageIds = next;\n  }\n\n  function isHeatmapVisible(messageId: string): boolean {\n    return heatmapMessageIds.has(messageId);\n  }\n\n  function formatTimestamp(timestamp: number): string {\n    return new Date(timestamp).toLocaleTimeString(\"en-US\", {\n      hour12: false,\n      hour: \"2-digit\",\n      minute: \"2-digit\",\n      second: \"2-digit\",\n    });\n  }\n\n  function getAttachmentIcon(attachment: MessageAttachment): string {\n    switch (attachment.type) {\n      case \"image\":\n        return \"🖼\";\n      case \"text\":\n        return \"📄\";\n      default:\n        return \"📎\";\n    }\n  }\n\n  function truncateName(name: string, maxLen: number = 25): string {\n    if (name.length <= maxLen) return name;\n    const dotIndex = name.lastIndexOf(\".\");\n    // names without an extension are truncated as a whole\n    if (dotIndex === -1) return name.slice(0, maxLen - 3) + \"...\";\n    const ext = name.slice(dotIndex);\n    const base = name.slice(0, dotIndex);\n    const available = Math.max(maxLen - ext.length - 3, 1);\n    return base.slice(0, available) + \"...\" + ext;\n  }\n\n  async function handleCopy(content: string, messageId: string) {\n    try {\n      await navigator.clipboard.writeText(content);\n      copiedMessageId = messageId;\n      setTimeout(() => {\n        copiedMessageId = null;\n      }, 2000);\n    } catch (error) {\n      console.error(\"Failed to copy:\", error);\n    }\n  }\n\n  function toggleThinkingVisibility(messageId: string) {\n    const next = new Set(expandedThinkingMessageIds);\n    if (next.has(messageId)) {\n      next.delete(messageId);\n    } else {\n      next.add(messageId);\n    }\n    expandedThinkingMessageIds = next;\n  }\n\n  function isThinkingExpanded(messageId: string): boolean {\n    return expandedThinkingMessageIds.has(messageId);\n  }\n\n  function handleStartEdit(messageId: string, content: string) {\n    editingMessageId = 
messageId;\n    editContent = content;\n    setTimeout(() => {\n      if (editTextareaRef) {\n        editTextareaRef.focus();\n        editTextareaRef.setSelectionRange(\n          editTextareaRef.value.length,\n          editTextareaRef.value.length,\n        );\n        // Auto-resize\n        editTextareaRef.style.height = \"auto\";\n        editTextareaRef.style.height =\n          Math.min(editTextareaRef.scrollHeight, 200) + \"px\";\n      }\n    }, 10);\n  }\n\n  function handleCancelEdit() {\n    editingMessageId = null;\n    editContent = \"\";\n  }\n\n  function handleSaveEdit() {\n    if (editingMessageId && editContent.trim()) {\n      editAndRegenerate(editingMessageId, editContent.trim());\n    }\n    editingMessageId = null;\n    editContent = \"\";\n  }\n\n  function handleEditKeydown(event: KeyboardEvent) {\n    if (event.key === \"Enter\" && !event.shiftKey) {\n      event.preventDefault();\n      handleSaveEdit();\n    } else if (event.key === \"Escape\") {\n      handleCancelEdit();\n    }\n  }\n\n  function handleEditInput() {\n    if (editTextareaRef) {\n      editTextareaRef.style.height = \"auto\";\n      editTextareaRef.style.height =\n        Math.min(editTextareaRef.scrollHeight, 200) + \"px\";\n    }\n  }\n\n  function handleDeleteClick(messageId: string) {\n    if (loading) return;\n    deleteConfirmId = messageId;\n  }\n\n  function handleConfirmDelete() {\n    if (deleteConfirmId) {\n      deleteMessage(deleteConfirmId);\n      deleteConfirmId = null;\n    }\n  }\n\n  function handleCancelDelete() {\n    deleteConfirmId = null;\n  }\n\n  function handleRegenerate() {\n    regenerateLastResponse();\n  }\n\n  // Check if a message is the last assistant message\n  function isLastAssistantMessage(messageId: string): boolean {\n    for (let i = messageList.length - 1; i >= 0; i--) {\n      if (messageList[i].role === \"assistant\") {\n        return messageList[i].id === messageId;\n      }\n    }\n    return false;\n  }\n</script>\n\n<div class=\"flex flex-col gap-4 sm:gap-6 {className}\">\n  {#each messageList as message, i (message.id)}\n    <div\n      class=\"group flex {message.role === 'user'\n        ? 'justify-end'\n        : 'justify-start'}\"\n    >\n      <div\n        class={message.role === \"user\"\n          ? 
\"max-w-[85%] sm:max-w-[70%] flex flex-col items-end\"\n          : \"w-full max-w-[98%] sm:max-w-[95%]\"}\n      >\n        {#if message.role === \"assistant\"}\n          <!-- Assistant message header -->\n          <div class=\"flex items-center gap-1.5 sm:gap-2 mb-1.5 sm:mb-2\">\n            <div\n              class=\"w-1.5 h-1.5 sm:w-2 sm:h-2 bg-exo-yellow rounded-full shadow-[0_0_10px_rgba(255,215,0,0.5)]\"\n            ></div>\n            <span\n              class=\"text-sm sm:text-xs text-exo-yellow tracking-[0.15em] sm:tracking-[0.2em] uppercase font-medium\"\n              >EXO</span\n            >\n            <span\n              class=\"text-xs sm:text-sm text-exo-light-gray tracking-wider tabular-nums\"\n              >{formatTimestamp(message.timestamp)}</span\n            >\n            {#if message.ttftMs || message.tps}\n              <span class=\"text-xs text-exo-light-gray/80 font-mono ml-2\">\n                {#if message.ttftMs}<span class=\"text-exo-light-gray/50\"\n                    >TTFT</span\n                  >\n                  {message.ttftMs.toFixed(\n                    0,\n                  )}ms{/if}{#if message.ttftMs && message.tps}<span\n                    class=\"text-exo-light-gray/30 mx-1\">•</span\n                  >{/if}{#if message.tps}{message.tps.toFixed(1)}\n                  <span class=\"text-exo-light-gray/50\">tok/s</span>{/if}\n              </span>\n            {/if}\n          </div>\n        {:else}\n          <!-- User message header -->\n          <div\n            class=\"flex items-center justify-end gap-1.5 sm:gap-2 mb-1.5 sm:mb-2\"\n          >\n            <span\n              class=\"text-xs sm:text-sm text-exo-light-gray tracking-wider tabular-nums\"\n              >{formatTimestamp(message.timestamp)}</span\n            >\n            <span\n              class=\"text-sm sm:text-xs text-exo-light-gray tracking-[0.1em] sm:tracking-[0.15em] uppercase\"\n              >QUERY</span\n            >\n            <div\n              class=\"w-1.5 h-1.5 sm:w-2 sm:h-2 bg-exo-light-gray/50 rounded-full\"\n            ></div>\n          </div>\n        {/if}\n\n        {#if deleteConfirmId === message.id}\n          <!-- Delete confirmation -->\n          <div class=\"bg-red-500/10 border border-red-500/30 rounded-lg p-3\">\n            <p class=\"text-xs text-red-400 mb-3\">\n              {#if i === messageList.length - 1}\n                Delete this message?\n              {:else}\n                Delete this message and all messages after it?\n              {/if}\n            </p>\n            <div class=\"flex gap-2 justify-end\">\n              <button\n                onclick={handleCancelDelete}\n                class=\"px-3 py-1.5 text-sm font-mono tracking-wider uppercase bg-exo-medium-gray/20 text-exo-light-gray border border-exo-medium-gray/30 rounded hover:bg-exo-medium-gray/30 transition-colors cursor-pointer\"\n              >\n                CANCEL\n              </button>\n              <button\n                onclick={handleConfirmDelete}\n                class=\"px-3 py-1.5 text-sm font-mono tracking-wider uppercase bg-red-500/20 text-red-400 border border-red-500/30 rounded hover:bg-red-500/30 transition-colors cursor-pointer\"\n              >\n                DELETE\n              </button>\n            </div>\n          </div>\n        {:else if editingMessageId === message.id}\n          <!-- Edit mode -->\n          <div class=\"command-panel rounded-lg p-3\">\n            <textarea\n              
bind:this={editTextareaRef}\n              bind:value={editContent}\n              onkeydown={handleEditKeydown}\n              oninput={handleEditInput}\n              class=\"w-full bg-exo-black/60 border border-exo-yellow/30 rounded px-3 py-2 text-sm text-foreground font-mono focus:outline-none focus:border-exo-yellow/50 resize-none\"\n              style=\"min-height: 60px; max-height: 200px;\"\n            ></textarea>\n            <div class=\"flex gap-2 justify-end mt-2\">\n              <button\n                onclick={handleCancelEdit}\n                class=\"px-3 py-1.5 text-sm font-mono tracking-wider uppercase bg-exo-medium-gray/20 text-exo-light-gray border border-exo-medium-gray/30 rounded hover:bg-exo-medium-gray/30 transition-colors cursor-pointer\"\n              >\n                CANCEL\n              </button>\n              <button\n                onclick={handleSaveEdit}\n                disabled={!editContent.trim()}\n                class=\"px-3 py-1.5 text-sm font-mono tracking-wider uppercase bg-transparent text-exo-yellow border border-exo-yellow/30 rounded hover:border-exo-yellow/50 transition-colors disabled:opacity-50 disabled:cursor-not-allowed flex items-center gap-1.5 cursor-pointer\"\n              >\n                <svg\n                  class=\"w-3 h-3\"\n                  fill=\"none\"\n                  viewBox=\"0 0 24 24\"\n                  stroke=\"currentColor\"\n                >\n                  <path\n                    stroke-linecap=\"round\"\n                    stroke-linejoin=\"round\"\n                    stroke-width=\"2\"\n                    d=\"M12 19l9 2-9-18-9 18 9-2zm0 0v-8\"\n                  />\n                </svg>\n                SEND\n              </button>\n            </div>\n          </div>\n        {:else}\n          <div\n            class={message.role === \"user\"\n              ? 
\"command-panel rounded-lg rounded-tr-sm inline-block\"\n              : \"command-panel rounded-lg rounded-tl-sm border-l-2 border-l-exo-yellow/50 block w-full\"}\n          >\n            {#if message.role === \"user\"}\n              <!-- User message styling -->\n              <div class=\"px-4 py-3\">\n                <!-- Attachments -->\n                {#if message.attachments && message.attachments.length > 0}\n                  <div class=\"flex flex-wrap gap-2 mb-3\">\n                    {#each message.attachments as attachment}\n                      <div\n                        class=\"flex items-center gap-2 bg-exo-dark-gray/60 border border-exo-yellow/20 rounded px-2 py-1 text-xs font-mono\"\n                      >\n                        {#if attachment.type === \"image\" && attachment.preview}\n                          <!-- svelte-ignore a11y_no_noninteractive_element_interactions, a11y_click_events_have_key_events -->\n                          <img\n                            src={attachment.preview}\n                            alt={attachment.name}\n                            class=\"w-12 h-12 object-cover rounded border border-exo-yellow/20 cursor-pointer hover:border-exo-yellow/50 transition-colors\"\n                            onclick={() => {\n                              if (attachment.preview)\n                                expandedImageSrc = attachment.preview;\n                            }}\n                          />\n                        {:else}\n                          <span>{getAttachmentIcon(attachment)}</span>\n                        {/if}\n                        <span class=\"text-exo-yellow\" title={attachment.name}\n                          >{truncateName(attachment.name)}</span\n                        >\n                      </div>\n                    {/each}\n                  </div>\n                {/if}\n\n                {#if message.content}\n                  <div\n                    class=\"text-xs text-foreground font-mono tracking-wide whitespace-pre-wrap break-words leading-relaxed\"\n                  >\n                    {message.content}\n                  </div>\n                {/if}\n              </div>\n            {:else}\n              <!-- Assistant message styling -->\n              <div class=\"p-3 sm:p-4\">\n                {#if loading && isLastAssistantMessage(message.id) && prefill && !message.content}\n                  <PrefillProgressBar progress={prefill} class=\"mb-3\" />\n                {/if}\n                {#if message.thinking && message.thinking.trim().length > 0}\n                  <div\n                    class=\"mb-3 rounded border border-exo-yellow/20 bg-exo-black/40\"\n                  >\n                    <button\n                      type=\"button\"\n                      class=\"w-full flex items-center justify-between px-3 py-2 text-xs font-mono uppercase tracking-[0.2em] text-exo-light-gray/80 hover:text-exo-yellow transition-colors cursor-pointer\"\n                      onclick={() => toggleThinkingVisibility(message.id)}\n                      aria-expanded={isThinkingExpanded(message.id)}\n                      aria-controls={`thinking-panel-${message.id}`}\n                    >\n                      <span class=\"flex items-center gap-2 tracking-[0.25em]\">\n                        <svg\n                          class={`w-3.5 h-3.5 text-current transition-transform duration-200 ${isThinkingExpanded(message.id) ? 
\"rotate-90\" : \"\"}`}\n                          fill=\"none\"\n                          viewBox=\"0 0 24 24\"\n                          stroke=\"currentColor\"\n                          aria-hidden=\"true\"\n                        >\n                          <path\n                            stroke-linecap=\"round\"\n                            stroke-linejoin=\"round\"\n                            stroke-width=\"2\"\n                            d=\"M9 5l7 7-7 7\"\n                          />\n                        </svg>\n                        <span>Thinking...</span>\n                      </span>\n                      <span\n                        class=\"text-[10px] tracking-[0.2em] text-exo-light-gray/60 ml-4\"\n                      >\n                        {isThinkingExpanded(message.id) ? \"HIDE\" : \"SHOW\"}\n                      </span>\n                    </button>\n                    {#if isThinkingExpanded(message.id)}\n                      <div\n                        id={`thinking-panel-${message.id}`}\n                        class=\"px-3 pb-3 text-xs text-exo-light-gray/90 font-mono whitespace-pre-wrap break-words leading-relaxed\"\n                      >\n                        {message.thinking.trim()}\n                      </div>\n                    {/if}\n                  </div>\n                {/if}\n\n                <!-- Generated Images -->\n                {#if message.attachments?.some((a) => a.type === \"generated-image\")}\n                  <div class=\"mb-3\">\n                    {#each message.attachments.filter((a) => a.type === \"generated-image\") as attachment}\n                      <div class=\"relative group/img inline-block\">\n                        <!-- svelte-ignore a11y_no_noninteractive_element_interactions, a11y_click_events_have_key_events -->\n                        <img\n                          src={attachment.preview}\n                          alt=\"\"\n                          class=\"max-w-full max-h-[512px] rounded-lg border border-exo-yellow/20 shadow-lg shadow-black/20 cursor-pointer\"\n                          onclick={() => {\n                            if (attachment.preview)\n                              expandedImageSrc = attachment.preview;\n                          }}\n                        />\n                        <!-- Button overlay -->\n                        <div\n                          class=\"absolute top-2 right-2 flex gap-1 opacity-0 group-hover/img:opacity-100 transition-opacity\"\n                        >\n                          <!-- Expand button -->\n                          <button\n                            type=\"button\"\n                            class=\"p-2 rounded-lg bg-exo-dark-gray/80 border border-exo-yellow/30 text-exo-yellow hover:bg-exo-dark-gray hover:border-exo-yellow/50 cursor-pointer\"\n                            onclick={() => {\n                              if (attachment.preview)\n                                expandedImageSrc = attachment.preview;\n                            }}\n                            title=\"Expand image\"\n                          >\n                            <svg\n                              class=\"w-4 h-4\"\n                              fill=\"none\"\n                              viewBox=\"0 0 24 24\"\n                              stroke=\"currentColor\"\n                              stroke-width=\"2\"\n                            >\n                              <path\n                              
  stroke-linecap=\"round\"\n                                stroke-linejoin=\"round\"\n                                d=\"M4 8V4m0 0h4M4 4l5 5m11-1V4m0 0h-4m4 0l-5 5M4 16v4m0 0h4m-4 0l5-5m11 5l-5-5m5 5v-4m0 4h-4\"\n                              />\n                            </svg>\n                          </button>\n                          <!-- Edit button -->\n                          <button\n                            type=\"button\"\n                            class=\"p-2 rounded-lg bg-exo-dark-gray/80 border border-exo-yellow/30 text-exo-yellow hover:bg-exo-dark-gray hover:border-exo-yellow/50 cursor-pointer\"\n                            onclick={() => {\n                              if (attachment.preview) {\n                                setEditingImage(attachment.preview, message);\n                              }\n                            }}\n                            title=\"Edit image\"\n                          >\n                            <svg\n                              class=\"w-4 h-4\"\n                              fill=\"none\"\n                              viewBox=\"0 0 24 24\"\n                              stroke=\"currentColor\"\n                              stroke-width=\"2\"\n                            >\n                              <path\n                                stroke-linecap=\"round\"\n                                stroke-linejoin=\"round\"\n                                d=\"M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z\"\n                              />\n                            </svg>\n                          </button>\n                          <!-- Download button -->\n                          <button\n                            type=\"button\"\n                            class=\"p-2 rounded-lg bg-exo-dark-gray/80 border border-exo-yellow/30 text-exo-yellow hover:bg-exo-dark-gray hover:border-exo-yellow/50 cursor-pointer\"\n                            onclick={() => {\n                              if (attachment.preview) {\n                                const link = document.createElement(\"a\");\n                                link.href = attachment.preview;\n                                const ext =\n                                  attachment.name?.split(\".\").pop() || \"png\";\n                                link.download = `generated-image-${Date.now()}.${ext}`;\n                                link.click();\n                              }\n                            }}\n                            title=\"Download image\"\n                          >\n                            <svg\n                              class=\"w-4 h-4\"\n                              fill=\"none\"\n                              viewBox=\"0 0 24 24\"\n                              stroke=\"currentColor\"\n                              stroke-width=\"2\"\n                            >\n                              <path\n                                stroke-linecap=\"round\"\n                                stroke-linejoin=\"round\"\n                                d=\"M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4\"\n                              />\n                            </svg>\n                          </button>\n                        </div>\n                      </div>\n                    {/each}\n                  </div>\n                {/if}\n\n                <div class=\"text-xs text-foreground\">\n        
          {#if message.content === \"Generating image...\" || message.content === \"Editing image...\" || message.content?.startsWith(\"Generating...\") || message.content?.startsWith(\"Editing...\")}\n                    <div class=\"flex items-center gap-3 text-exo-yellow\">\n                      <div class=\"relative\">\n                        <div\n                          class=\"w-8 h-8 border-2 border-exo-yellow/30 border-t-exo-yellow rounded-full animate-spin\"\n                        ></div>\n                        <svg\n                          class=\"absolute inset-0 w-8 h-8 p-1.5 text-exo-yellow/60\"\n                          fill=\"none\"\n                          viewBox=\"0 0 24 24\"\n                          stroke=\"currentColor\"\n                          stroke-width=\"2\"\n                        >\n                          <rect\n                            x=\"3\"\n                            y=\"3\"\n                            width=\"18\"\n                            height=\"18\"\n                            rx=\"2\"\n                            ry=\"2\"\n                          />\n                          <circle cx=\"8.5\" cy=\"8.5\" r=\"1.5\" />\n                          <polyline points=\"21 15 16 10 5 21\" />\n                        </svg>\n                      </div>\n                      <span class=\"font-mono tracking-wider uppercase text-sm\"\n                        >{message.content}</span\n                      >\n                    </div>\n                  {:else if message.content || (loading && !message.attachments?.some((a) => a.type === \"generated-image\"))}\n                    {#if isHeatmapVisible(message.id) && message.tokens && message.tokens.length > 0}\n                      <TokenHeatmap\n                        tokens={message.tokens}\n                        isGenerating={loading &&\n                          isLastAssistantMessage(message.id)}\n                        onRegenerateFrom={(tokenIndex) =>\n                          regenerateFromToken(message.id, tokenIndex)}\n                      />\n                    {:else}\n                      <MarkdownContent\n                        content={message.content || (loading ? response : \"\")}\n                      />\n                      {#if loading && !message.content}\n                        <span\n                          class=\"inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink\"\n                        ></span>\n                      {/if}\n                    {/if}\n                  {/if}\n                </div>\n              </div>\n            {/if}\n          </div>\n\n          <!-- Action buttons -->\n          <div\n            class=\"flex items-center gap-1 mt-1.5 opacity-0 group-hover:opacity-100 transition-opacity {message.role ===\n            'user'\n              ? 
'justify-end'\n              : 'justify-start'}\"\n          >\n            <!-- Copy button -->\n            <button\n              onclick={() => handleCopy(message.content, message.id)}\n              class=\"p-1.5 text-exo-light-gray hover:text-exo-yellow transition-colors rounded cursor-pointer\"\n              title=\"Copy message\"\n            >\n              {#if copiedMessageId === message.id}\n                <svg\n                  class=\"w-3.5 h-3.5 text-green-400\"\n                  fill=\"none\"\n                  viewBox=\"0 0 24 24\"\n                  stroke=\"currentColor\"\n                >\n                  <path\n                    stroke-linecap=\"round\"\n                    stroke-linejoin=\"round\"\n                    stroke-width=\"2\"\n                    d=\"M5 13l4 4L19 7\"\n                  />\n                </svg>\n              {:else}\n                <svg\n                  class=\"w-3.5 h-3.5\"\n                  fill=\"none\"\n                  viewBox=\"0 0 24 24\"\n                  stroke=\"currentColor\"\n                >\n                  <path\n                    stroke-linecap=\"round\"\n                    stroke-linejoin=\"round\"\n                    stroke-width=\"2\"\n                    d=\"M8 16H6a2 2 0 01-2-2V6a2 2 0 012-2h8a2 2 0 012 2v2m-6 12h8a2 2 0 002-2v-8a2 2 0 00-2-2h-8a2 2 0 00-2 2v8a2 2 0 002 2z\"\n                  />\n                </svg>\n              {/if}\n            </button>\n\n            <!-- Edit button (user messages only) -->\n            {#if message.role === \"user\"}\n              <button\n                onclick={() => handleStartEdit(message.id, message.content)}\n                class=\"p-1.5 text-exo-light-gray hover:text-exo-yellow transition-colors rounded cursor-pointer\"\n                title=\"Edit message\"\n              >\n                <svg\n                  class=\"w-3.5 h-3.5\"\n                  fill=\"none\"\n                  viewBox=\"0 0 24 24\"\n                  stroke=\"currentColor\"\n                >\n                  <path\n                    stroke-linecap=\"round\"\n                    stroke-linejoin=\"round\"\n                    stroke-width=\"2\"\n                    d=\"M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z\"\n                  />\n                </svg>\n              </button>\n            {/if}\n\n            <!-- Uncertainty heatmap toggle (assistant messages with tokens) -->\n            {#if message.role === \"assistant\" && message.tokens && message.tokens.length > 0}\n              <button\n                onclick={() => toggleHeatmap(message.id)}\n                class=\"p-1.5 transition-colors rounded cursor-pointer {isHeatmapVisible(\n                  message.id,\n                )\n                  ? 'text-exo-yellow'\n                  : 'text-exo-light-gray hover:text-exo-yellow'}\"\n                title={isHeatmapVisible(message.id)\n                  ? 
\"Hide uncertainty heatmap\"\n                  : \"Show uncertainty heatmap\"}\n              >\n                <svg\n                  class=\"w-3.5 h-3.5\"\n                  fill=\"none\"\n                  viewBox=\"0 0 24 24\"\n                  stroke=\"currentColor\"\n                >\n                  <path\n                    stroke-linecap=\"round\"\n                    stroke-linejoin=\"round\"\n                    stroke-width=\"2\"\n                    d=\"M9 19v-6a2 2 0 00-2-2H5a2 2 0 00-2 2v6a2 2 0 002 2h2a2 2 0 002-2zm0 0V9a2 2 0 012-2h2a2 2 0 012 2v10m-6 0a2 2 0 002 2h2a2 2 0 002-2m0 0V5a2 2 0 012-2h2a2 2 0 012 2v14a2 2 0 01-2 2h-2a2 2 0 01-2-2z\"\n                  />\n                </svg>\n              </button>\n            {/if}\n\n            <!-- Regenerate button (last assistant message only) -->\n            {#if message.role === \"assistant\" && isLastAssistantMessage(message.id) && !loading}\n              <button\n                onclick={handleRegenerate}\n                class=\"p-1.5 text-exo-light-gray hover:text-exo-yellow transition-colors rounded cursor-pointer\"\n                title=\"Regenerate response\"\n              >\n                <svg\n                  class=\"w-3.5 h-3.5\"\n                  fill=\"none\"\n                  viewBox=\"0 0 24 24\"\n                  stroke=\"currentColor\"\n                >\n                  <path\n                    stroke-linecap=\"round\"\n                    stroke-linejoin=\"round\"\n                    stroke-width=\"2\"\n                    d=\"M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15\"\n                  />\n                </svg>\n              </button>\n            {/if}\n\n            <!-- Delete button -->\n            <button\n              onclick={() => handleDeleteClick(message.id)}\n              disabled={loading}\n              class=\"p-1.5 transition-colors rounded {loading\n                ? 'text-exo-light-gray/30 cursor-not-allowed'\n                : 'text-exo-light-gray hover:text-red-400 hover:bg-red-500/10 cursor-pointer'}\"\n              title={loading\n                ? 
\"Cannot delete while generating\"\n                : \"Delete message\"}\n            >\n              <svg\n                class=\"w-3.5 h-3.5\"\n                fill=\"none\"\n                viewBox=\"0 0 24 24\"\n                stroke=\"currentColor\"\n              >\n                <path\n                  stroke-linecap=\"round\"\n                  stroke-linejoin=\"round\"\n                  stroke-width=\"2\"\n                  d=\"M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16\"\n                />\n              </svg>\n            </button>\n          </div>\n        {/if}\n      </div>\n    </div>\n  {/each}\n\n  {#if messageList.length === 0}\n    <div\n      class=\"flex-1 flex flex-col items-center justify-center text-center pt-[20vh]\"\n    >\n      <div\n        class=\"w-12 h-12 sm:w-16 sm:h-16 border border-exo-yellow/20 rounded-full flex items-center justify-center mb-3 sm:mb-4\"\n      >\n        <div\n          class=\"w-6 h-6 sm:w-8 sm:h-8 border border-exo-yellow/40 rounded-full flex items-center justify-center\"\n        >\n          <div\n            class=\"w-1.5 h-1.5 sm:w-2 sm:h-2 bg-exo-yellow/60 rounded-full\"\n          ></div>\n        </div>\n      </div>\n      <p\n        class=\"text-xs sm:text-sm text-exo-light-gray tracking-[0.15em] sm:tracking-[0.2em] uppercase\"\n      >\n        AWAITING INPUT\n      </p>\n      <p class=\"text-xs text-white/30 tracking-wider mt-1.5 font-mono\">\n        Type a message below &middot; Shift+Enter for newline\n      </p>\n    </div>\n  {/if}\n\n  <!-- Invisible element for container reference -->\n  <div bind:this={containerRef}></div>\n\n  <!-- Scroll to bottom button -->\n  {#if showScrollButton}\n    <button\n      type=\"button\"\n      onclick={scrollToBottom}\n      class=\"sticky bottom-4 left-1/2 -translate-x-1/2 w-10 h-10 rounded-full bg-exo-dark-gray/90 border border-exo-medium-gray/50 flex items-center justify-center text-exo-light-gray hover:text-exo-yellow hover:border-exo-yellow/50 transition-all shadow-lg cursor-pointer z-10\"\n      title=\"Scroll to bottom\"\n      aria-label=\"Scroll to bottom of messages\"\n    >\n      <svg\n        class=\"w-5 h-5\"\n        fill=\"none\"\n        viewBox=\"0 0 24 24\"\n        stroke=\"currentColor\"\n      >\n        <path\n          stroke-linecap=\"round\"\n          stroke-linejoin=\"round\"\n          stroke-width=\"2\"\n          d=\"M19 14l-7 7m0 0l-7-7m7 7V3\"\n        />\n      </svg>\n    </button>\n  {/if}\n</div>\n\n<ImageLightbox\n  src={expandedImageSrc}\n  onclose={() => (expandedImageSrc = null)}\n/>\n"
  },
  {
    "path": "dashboard/src/lib/components/ChatModelSelector.svelte",
    "content": "<script lang=\"ts\" module>\n  export interface ChatModelInfo {\n    id: string;\n    name: string;\n    base_model: string;\n    storage_size_megabytes: number;\n    capabilities: string[];\n    family: string;\n    quantization: string;\n  }\n\n  // Auto mode tier list (for when user just starts typing)\n  export const AUTO_TIERS: string[][] = [\n    // Tier 1 (frontier)\n    [\"DeepSeek V3.1\", \"GLM-5\", \"Kimi K2.5\", \"Qwen3 Coder Next\"],\n    // Tier 2 (excellent)\n    [\n      \"Kimi K2\",\n      \"Qwen3 235B\",\n      \"MiniMax M2.5\",\n      \"Step 3.5 Flash\",\n      \"Qwen3 Next 80B\",\n    ],\n    // Tier 3 (great)\n    [\n      \"GLM 4.7\",\n      \"MiniMax M2.1\",\n      \"Qwen3 Coder 480B\",\n      \"GLM 4.5 Air\",\n      \"Llama 3.3 70B\",\n    ],\n    // Tier 4 (good)\n    [\"GPT-OSS 120B\", \"Qwen3 30B\", \"Llama 3.1 70B\", \"GLM 4.7 Flash\"],\n    // Tier 5 (small/fast)\n    [\n      \"Llama 3.1 8B\",\n      \"GPT-OSS 20B\",\n      \"Llama 3.2 3B\",\n      \"Qwen3 0.6B\",\n      \"Llama 3.2 1B\",\n    ],\n  ];\n\n  /** Return the tier index (0 = best) for a base_model name. */\n  export function getAutoTierIndex(baseModel: string): number {\n    for (let i = 0; i < AUTO_TIERS.length; i++) {\n      if (AUTO_TIERS[i].includes(baseModel)) return i;\n    }\n    return AUTO_TIERS.length; // not in any tier → lowest priority\n  }\n\n  /** Auto mode: walk tiers top-down, pick biggest fitting variant from highest tier. */\n  export function pickAutoModel(\n    modelList: ChatModelInfo[],\n    memoryGB: number,\n  ): ChatModelInfo | null {\n    for (const tier of AUTO_TIERS) {\n      const candidates: ChatModelInfo[] = [];\n      for (const baseModel of tier) {\n        const variants = modelList\n          .filter(\n            (m) =>\n              m.base_model === baseModel &&\n              (m.storage_size_megabytes || 0) / 1024 <= memoryGB &&\n              (m.storage_size_megabytes || 0) > 0,\n          )\n          .sort(\n            (a, b) =>\n              (b.storage_size_megabytes || 0) - (a.storage_size_megabytes || 0),\n          );\n        if (variants[0]) candidates.push(variants[0]);\n      }\n      if (candidates.length > 0) {\n        candidates.sort(\n          (a, b) =>\n            (b.storage_size_megabytes || 0) - (a.storage_size_megabytes || 0),\n        );\n        return candidates[0];\n      }\n    }\n    return null;\n  }\n</script>\n\n<script lang=\"ts\">\n  interface CategoryRecommendation {\n    category: \"coding\" | \"writing\" | \"agentic\" | \"biggest\";\n    label: string;\n    model: ChatModelInfo | null;\n    tooltip: string;\n  }\n\n  interface Props {\n    models: ChatModelInfo[];\n    clusterLabel: string;\n    totalMemoryGB: number;\n    onSelect: (modelId: string, category: string) => void;\n    onAddModel: () => void;\n    class?: string;\n  }\n\n  let {\n    models,\n    clusterLabel,\n    totalMemoryGB,\n    onSelect,\n    onAddModel,\n    class: className = \"\",\n  }: Props = $props();\n\n  // --- Hardcoded Rankings ---\n  const CODING_RANKING = [\n    \"Qwen3 Coder Next\",\n    \"Qwen3 Coder 480B\",\n    \"Qwen3 30B\",\n    \"GPT-OSS 20B\",\n    \"Llama 3.1 8B\",\n    \"Llama 3.2 3B\",\n    \"Qwen3 0.6B\",\n  ];\n\n  const WRITING_RANKING = [\n    \"Kimi K2.5\",\n    \"Kimi K2\",\n    \"Qwen3 Next 80B\",\n    \"Llama 3.3 70B\",\n    \"MiniMax M2.5\",\n    \"GLM 4.5 Air\",\n    \"GLM 4.7 Flash\",\n    \"GPT-OSS 20B\",\n    \"Llama 3.1 8B\",\n    \"Llama 3.2 3B\",\n    \"Qwen3 0.6B\",\n  ];\n\n  const 
AGENTIC_RANKING = [\n    \"DeepSeek V3.1\",\n    \"GLM-5\",\n    \"Qwen3 235B\",\n    \"Step 3.5 Flash\",\n    \"GLM 4.7\",\n    \"MiniMax M2.1\",\n    \"GPT-OSS 120B\",\n    \"Llama 3.3 70B\",\n    \"Llama 3.1 70B\",\n    \"GLM 4.7 Flash\",\n    \"GPT-OSS 20B\",\n    \"Qwen3 30B\",\n    \"Llama 3.1 8B\",\n    \"Llama 3.2 3B\",\n    \"Qwen3 0.6B\",\n  ];\n\n  function getModelSizeGB(m: ChatModelInfo): number {\n    return (m.storage_size_megabytes || 0) / 1024;\n  }\n\n  function fitsInMemory(m: ChatModelInfo): boolean {\n    return getModelSizeGB(m) <= totalMemoryGB && getModelSizeGB(m) > 0;\n  }\n\n  /** For a given base_model name, find the biggest quant variant that fits in memory. */\n  function pickBestVariant(baseModel: string): ChatModelInfo | null {\n    const variants = models\n      .filter((m) => m.base_model === baseModel && fitsInMemory(m))\n      .sort((a, b) => getModelSizeGB(b) - getModelSizeGB(a));\n    return variants[0] ?? null;\n  }\n\n  /** Walk a ranked list of base_model names, return the first that has a fitting variant. */\n  function pickFromRanking(ranking: string[]): ChatModelInfo | null {\n    for (const baseModel of ranking) {\n      const pick = pickBestVariant(baseModel);\n      if (pick) return pick;\n    }\n    return null;\n  }\n\n  /** Pick the single biggest model that fits. */\n  function pickBiggest(): ChatModelInfo | null {\n    const fitting = models\n      .filter((m) => fitsInMemory(m))\n      .sort((a, b) => getModelSizeGB(b) - getModelSizeGB(a));\n    return fitting[0] ?? null;\n  }\n\n  const recommendations = $derived.by((): CategoryRecommendation[] => {\n    return [\n      {\n        category: \"coding\",\n        label: \"Best for Coding\",\n        model: pickFromRanking(CODING_RANKING),\n        tooltip:\n          \"Ranked by coding benchmark performance (LiveCodeBench, SWE-bench)\",\n      },\n      {\n        category: \"writing\",\n        label: \"Best for Writing\",\n        model: pickFromRanking(WRITING_RANKING),\n        tooltip: \"Ranked by creative writing quality and instruction following\",\n      },\n      {\n        category: \"agentic\",\n        label: \"Best Agentic\",\n        model: pickFromRanking(AGENTIC_RANKING),\n        tooltip: \"Ranked by reasoning, planning, and tool-use capability\",\n      },\n      {\n        category: \"biggest\",\n        label: \"Biggest\",\n        model: pickBiggest(),\n        tooltip: \"Largest model that fits in your available memory\",\n      },\n    ];\n  });\n\n  function formatSize(mb: number): string {\n    const gb = mb / 1024;\n    if (gb >= 100) return `${Math.round(gb)} GB`;\n    return `${gb.toFixed(1)} GB`;\n  }\n\n  // Category icons (SVG paths)\n  const categoryIcons: Record<string, string> = {\n    coding:\n      \"M8 9l3 3-3 3m5 0h3M5 20h14a2 2 0 002-2V6a2 2 0 00-2-2H5a2 2 0 00-2 2v12a2 2 0 002 2z\",\n    writing:\n      \"M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z\",\n    agentic:\n      \"M9.663 17h4.673M12 3v1m6.364 1.636l-.707.707M21 12h-1M4 12H3m3.343-5.657l-.707-.707m2.828 9.9a5 5 0 117.072 0l-.548.547A3.374 3.374 0 0014 18.469V19a2 2 0 11-4 0v-.531c0-.895-.356-1.754-.988-2.386l-.548-.547z\",\n    biggest:\n      \"M19 11H5m14 0a2 2 0 012 2v6a2 2 0 01-2 2H5a2 2 0 01-2-2v-6a2 2 0 012-2m14 0V9a2 2 0 00-2-2M5 11V9a2 2 0 012-2m0 0V5a2 2 0 012-2h6a2 2 0 012 2v2M7 7h10\",\n  };\n\n  let hoveredTooltip = $state<string | null>(null);\n  let tooltipAnchor = $state<{ x: number; y: number } | 
null>(null);\n\n  function showTooltip(category: string, e: MouseEvent | FocusEvent) {\n    hoveredTooltip = category;\n    const target = e.currentTarget as HTMLElement;\n    const rect = target.getBoundingClientRect();\n    tooltipAnchor = { x: rect.left + rect.width / 2, y: rect.top };\n  }\n\n  function hideTooltip() {\n    hoveredTooltip = null;\n    tooltipAnchor = null;\n  }\n</script>\n\n<div class=\"flex flex-col items-center justify-center gap-6 {className}\">\n  <!-- Header -->\n  <div class=\"text-center\">\n    <p class=\"text-xs text-exo-light-gray uppercase tracking-[0.2em] mb-1\">\n      Recommended for your\n    </p>\n    <p class=\"text-sm text-white font-mono tracking-wide\">{clusterLabel}</p>\n  </div>\n\n  <!-- Category Cards Grid -->\n  <div class=\"grid grid-cols-2 gap-3 w-full max-w-md\">\n    {#each recommendations as rec}\n      {#if rec.model}\n        <button\n          type=\"button\"\n          onclick={() => rec.model && onSelect(rec.model.id, rec.category)}\n          class=\"group relative flex flex-col items-start gap-2 p-4 rounded-lg border border-exo-medium-gray/50 bg-exo-dark-gray/50 hover:border-exo-yellow/40 hover:bg-exo-dark-gray transition-all duration-200 cursor-pointer text-left\"\n        >\n          <!-- Category icon + label -->\n          <div class=\"flex items-center gap-2 w-full\">\n            <svg\n              class=\"w-4 h-4 text-exo-yellow/70 group-hover:text-exo-yellow transition-colors flex-shrink-0\"\n              fill=\"none\"\n              viewBox=\"0 0 24 24\"\n              stroke=\"currentColor\"\n              stroke-width=\"1.5\"\n            >\n              <path\n                stroke-linecap=\"round\"\n                stroke-linejoin=\"round\"\n                d={categoryIcons[rec.category]}\n              />\n            </svg>\n            <span\n              class=\"text-xs font-mono uppercase tracking-wider text-exo-light-gray group-hover:text-white transition-colors\"\n            >\n              {rec.label}\n            </span>\n            <!-- Info tooltip -->\n            <div class=\"ml-auto flex-shrink-0\">\n              <span\n                role=\"button\"\n                tabindex=\"-1\"\n                class=\"text-exo-light-gray/40 hover:text-exo-light-gray transition-colors cursor-help inline-flex\"\n                onmouseenter={(e: MouseEvent) => showTooltip(rec.category, e)}\n                onmouseleave={() => hideTooltip()}\n                onclick={(e: MouseEvent) => {\n                  e.stopPropagation();\n                  if (hoveredTooltip === rec.category) {\n                    hideTooltip();\n                  } else {\n                    showTooltip(rec.category, e);\n                  }\n                }}\n              >\n                <svg\n                  class=\"w-3.5 h-3.5\"\n                  fill=\"none\"\n                  viewBox=\"0 0 24 24\"\n                  stroke=\"currentColor\"\n                  stroke-width=\"2\"\n                >\n                  <circle cx=\"12\" cy=\"12\" r=\"10\" />\n                  <path d=\"M12 16v-4m0-4h.01\" />\n                </svg>\n              </span>\n            </div>\n          </div>\n\n          <!-- Model name + size -->\n          <div class=\"w-full\">\n            <p class=\"text-sm text-white font-mono truncate\">\n              {rec.model.base_model}\n            </p>\n            <p class=\"text-xs text-exo-light-gray/60 font-mono mt-0.5\">\n              {formatSize(rec.model.storage_size_megabytes)}\n    
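          <!-- Quantization tag distinguishes variants of the same base_model -->\n    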
          {#if rec.model.quantization}\n                <span class=\"text-exo-light-gray/40\"\n                  >&middot; {rec.model.quantization}</span\n                >\n              {/if}\n            </p>\n          </div>\n        </button>\n      {:else}\n        <!-- No model fits for this category -->\n        <div\n          class=\"flex flex-col items-start gap-2 p-4 rounded-lg border border-exo-medium-gray/30 bg-exo-dark-gray/30 opacity-50\"\n        >\n          <div class=\"flex items-center gap-2\">\n            <svg\n              class=\"w-4 h-4 text-exo-light-gray/40 flex-shrink-0\"\n              fill=\"none\"\n              viewBox=\"0 0 24 24\"\n              stroke=\"currentColor\"\n              stroke-width=\"1.5\"\n            >\n              <path\n                stroke-linecap=\"round\"\n                stroke-linejoin=\"round\"\n                d={categoryIcons[rec.category]}\n              />\n            </svg>\n            <span\n              class=\"text-xs font-mono uppercase tracking-wider text-exo-light-gray/50\"\n              >{rec.label}</span\n            >\n          </div>\n          <p class=\"text-xs text-exo-light-gray/40 font-mono\">No model fits</p>\n        </div>\n      {/if}\n    {/each}\n  </div>\n\n  <!-- Add Model Button -->\n  <button\n    type=\"button\"\n    onclick={onAddModel}\n    class=\"flex items-center gap-2 px-4 py-2 text-xs font-mono uppercase tracking-wider text-exo-light-gray hover:text-exo-yellow border border-exo-medium-gray/30 hover:border-exo-yellow/30 rounded-lg transition-all duration-200 cursor-pointer\"\n  >\n    <svg\n      class=\"w-3.5 h-3.5\"\n      fill=\"none\"\n      viewBox=\"0 0 24 24\"\n      stroke=\"currentColor\"\n      stroke-width=\"2\"\n    >\n      <path stroke-linecap=\"round\" stroke-linejoin=\"round\" d=\"M12 4v16m8-8H4\" />\n    </svg>\n    Add Model\n  </button>\n\n  <!-- Auto hint -->\n  <p class=\"text-xs text-exo-light-gray/40 font-mono tracking-wide text-center\">\n    Or just start typing &mdash; we'll pick the best model automatically\n  </p>\n</div>\n\n<!-- Fixed-position tooltip (escapes overflow-hidden ancestors) -->\n{#if hoveredTooltip && tooltipAnchor}\n  {@const rec = recommendations.find((r) => r.category === hoveredTooltip)}\n  {#if rec}\n    <div\n      class=\"fixed z-[9999] px-3 py-2 bg-exo-black border border-exo-medium-gray/50 rounded text-xs text-exo-light-gray whitespace-nowrap shadow-lg pointer-events-none\"\n      style=\"left: {tooltipAnchor.x}px; top: {tooltipAnchor.y -\n        8}px; transform: translate(-50%, -100%);\"\n    >\n      {rec.tooltip}\n    </div>\n  {/if}\n{/if}\n"
  },
  {
    "path": "dashboard/src/lib/components/ChatSidebar.svelte",
    "content": "<script lang=\"ts\">\n  import {\n    conversations,\n    activeConversationId,\n    loadConversation,\n    deleteConversation,\n    deleteAllConversations,\n    renameConversation,\n    clearChat,\n    instances,\n    debugMode,\n    toggleDebugMode,\n    topologyOnlyMode,\n    toggleTopologyOnlyMode,\n  } from \"$lib/stores/app.svelte\";\n\n  interface Props {\n    class?: string;\n    onNewChat?: () => void;\n    onSelectConversation?: () => void;\n    isMobileDrawer?: boolean;\n    isOpen?: boolean;\n    onClose?: () => void;\n  }\n\n  let {\n    class: className = \"\",\n    onNewChat,\n    onSelectConversation,\n    isMobileDrawer = false,\n    isOpen = false,\n    onClose,\n  }: Props = $props();\n\n  const conversationList = $derived(conversations());\n  const activeId = $derived(activeConversationId());\n  const instanceData = $derived(instances());\n  const debugEnabled = $derived(debugMode());\n  const topologyOnlyEnabled = $derived(topologyOnlyMode());\n\n  let searchQuery = $state(\"\");\n  let editingId = $state<string | null>(null);\n  let editingName = $state(\"\");\n  let deleteConfirmId = $state<string | null>(null);\n  let showDeleteAllConfirm = $state(false);\n\n  const filteredConversations = $derived(\n    searchQuery.trim()\n      ? conversationList.filter((c) =>\n          c.name.toLowerCase().includes(searchQuery.toLowerCase()),\n        )\n      : conversationList,\n  );\n\n  function handleNewChat() {\n    onNewChat?.();\n  }\n\n  function handleSelectConversation(id: string) {\n    onSelectConversation?.();\n    loadConversation(id);\n    // Close mobile drawer when selecting a conversation\n    if (isMobileDrawer && isOpen) {\n      onClose?.();\n    }\n  }\n\n  function handleStartEdit(id: string, name: string, event: MouseEvent) {\n    event.stopPropagation();\n    editingId = id;\n    editingName = name;\n  }\n\n  function handleSaveEdit() {\n    if (editingId && editingName.trim()) {\n      renameConversation(editingId, editingName.trim());\n    }\n    editingId = null;\n    editingName = \"\";\n  }\n\n  function handleCancelEdit() {\n    editingId = null;\n    editingName = \"\";\n  }\n\n  function handleEditKeydown(event: KeyboardEvent) {\n    if (event.key === \"Enter\") {\n      handleSaveEdit();\n    } else if (event.key === \"Escape\") {\n      handleCancelEdit();\n    }\n  }\n\n  function handleDeleteClick(id: string, event: MouseEvent) {\n    event.stopPropagation();\n    deleteConfirmId = id;\n  }\n\n  function handleConfirmDelete() {\n    if (deleteConfirmId) {\n      deleteConversation(deleteConfirmId);\n      deleteConfirmId = null;\n    }\n  }\n\n  function handleCancelDelete() {\n    deleteConfirmId = null;\n  }\n\n  function handleDeleteAllClick() {\n    showDeleteAllConfirm = true;\n  }\n\n  function handleConfirmDeleteAll() {\n    deleteAllConversations();\n    showDeleteAllConfirm = false;\n  }\n\n  function handleCancelDeleteAll() {\n    showDeleteAllConfirm = false;\n  }\n\n  function formatDate(timestamp: number): string {\n    const date = new Date(timestamp);\n    const now = new Date();\n    const diffDays = Math.floor(\n      (now.getTime() - date.getTime()) / (1000 * 60 * 60 * 24),\n    );\n\n    if (diffDays === 0) {\n      return date.toLocaleTimeString(\"en-US\", {\n        hour: \"2-digit\",\n        minute: \"2-digit\",\n      });\n    } else if (diffDays === 1) {\n      return \"Yesterday\";\n    } else if (diffDays < 7) {\n      return date.toLocaleDateString(\"en-US\", { weekday: \"short\" });\n    } else 
{\n      return date.toLocaleDateString(\"en-US\", {\n        month: \"short\",\n        day: \"numeric\",\n      });\n    }\n  }\n\n  function getLastAssistantStats(\n    conversation: (typeof conversationList)[0],\n  ): { ttftMs?: number; tps?: number } | null {\n    // Find the last assistant message with stats\n    for (let i = conversation.messages.length - 1; i >= 0; i--) {\n      const msg = conversation.messages[i];\n      if (msg.role === \"assistant\" && (msg.ttftMs || msg.tps)) {\n        return { ttftMs: msg.ttftMs, tps: msg.tps };\n      }\n    }\n    return null;\n  }\n\n  function formatModelName(modelId: string | null | undefined): string {\n    if (!modelId) return \"Unknown Model\";\n    const parts = modelId.split(\"/\");\n    const tail = parts[parts.length - 1] || modelId;\n    return tail || modelId;\n  }\n\n  function formatStrategy(\n    sharding: string | null | undefined,\n    instanceType: string | null | undefined,\n  ): string {\n    const shardLabel = sharding ?? \"Unknown\";\n    const typeLabel = instanceType ?? null;\n    return typeLabel ? `${shardLabel} (${typeLabel})` : shardLabel;\n  }\n\n  function getTaggedValue(obj: unknown): [string | null, unknown] {\n    if (!obj || typeof obj !== \"object\") return [null, null];\n    const keys = Object.keys(obj as Record<string, unknown>);\n    if (keys.length === 1) {\n      return [keys[0], (obj as Record<string, unknown>)[keys[0]]];\n    }\n    return [null, null];\n  }\n\n  function extractInstanceModelId(instanceWrapped: unknown): string | null {\n    const [, instance] = getTaggedValue(instanceWrapped);\n    if (!instance || typeof instance !== \"object\") return null;\n    const inst = instance as { shardAssignments?: { modelId?: string } };\n    return inst.shardAssignments?.modelId ?? null;\n  }\n\n  function describeInstance(instanceWrapped: unknown): {\n    sharding: string | null;\n    instanceType: string | null;\n  } {\n    const [instanceTag, instance] = getTaggedValue(instanceWrapped);\n    if (!instance || typeof instance !== \"object\") {\n      return { sharding: null, instanceType: null };\n    }\n\n    let instanceType: string | null = null;\n    if (instanceTag === \"MlxRingInstance\") instanceType = \"MLX Ring\";\n    else if (instanceTag === \"MlxJacclInstance\") instanceType = \"MLX RDMA\";\n\n    let sharding: string | null = null;\n    const inst = instance as {\n      shardAssignments?: { runnerToShard?: Record<string, unknown> };\n    };\n    const runnerToShard = inst.shardAssignments?.runnerToShard || {};\n    const firstShardWrapped = Object.values(runnerToShard)[0];\n    if (firstShardWrapped) {\n      const [shardTag] = getTaggedValue(firstShardWrapped);\n      if (shardTag === \"PipelineShardMetadata\") sharding = \"Pipeline\";\n      else if (shardTag === \"TensorShardMetadata\") sharding = \"Tensor\";\n      else if (shardTag === \"PrefillDecodeShardMetadata\")\n        sharding = \"Prefill/Decode\";\n    }\n\n    return { sharding, instanceType };\n  }\n\n  function resolveConversationInfo(\n    conversation: (typeof conversationList)[0],\n  ): { modelLabel: string; strategyLabel: string } {\n    // Attempt to match conversation model to an instance\n    let matchedInstance: unknown = null;\n    let modelId = conversation.modelId ?? 
null;\n\n    if (modelId) {\n      for (const [, instanceWrapper] of Object.entries(instanceData)) {\n        const candidate = extractInstanceModelId(instanceWrapper);\n        if (candidate === modelId) {\n          matchedInstance = instanceWrapper;\n          break;\n        }\n      }\n    }\n\n    // Fallback: use the first available instance if no explicit match\n    if (!matchedInstance) {\n      const firstInstance = Object.values(instanceData)[0];\n      if (firstInstance) {\n        matchedInstance = firstInstance;\n        modelId = modelId ?? extractInstanceModelId(firstInstance);\n      }\n    }\n\n    const instanceDetails = matchedInstance\n      ? describeInstance(matchedInstance)\n      : { sharding: null, instanceType: null };\n    const displayModel = modelId ?? conversation.modelId ?? null;\n    const sharding =\n      conversation.sharding ?? instanceDetails.sharding ?? \"Unknown\";\n    const instanceType =\n      conversation.instanceType ?? instanceDetails.instanceType;\n\n    return {\n      modelLabel: formatModelName(displayModel),\n      strategyLabel: formatStrategy(sharding, instanceType),\n    };\n  }\n</script>\n\n{#snippet sidebarContent()}\n  <!-- Header -->\n  <div class=\"p-4\">\n    <button\n      onclick={handleNewChat}\n      class=\"w-full flex items-center justify-center gap-2 py-2.5 px-4 bg-transparent border border-exo-yellow/30 text-exo-yellow text-xs font-mono tracking-wider uppercase hover:border-exo-yellow/50 transition-all cursor-pointer\"\n    >\n      <svg\n        class=\"w-4 h-4\"\n        fill=\"none\"\n        viewBox=\"0 0 24 24\"\n        stroke=\"currentColor\"\n      >\n        <path\n          stroke-linecap=\"round\"\n          stroke-linejoin=\"round\"\n          stroke-width=\"2\"\n          d=\"M12 4v16m8-8H4\"\n        />\n      </svg>\n      NEW CHAT\n    </button>\n  </div>\n\n  <!-- Search -->\n  <div class=\"px-4 py-3\">\n    <div class=\"relative\">\n      <svg\n        class=\"absolute left-3 top-1/2 -translate-y-1/2 w-3.5 h-3.5 text-white/50\"\n        fill=\"none\"\n        viewBox=\"0 0 24 24\"\n        stroke=\"currentColor\"\n      >\n        <path\n          stroke-linecap=\"round\"\n          stroke-linejoin=\"round\"\n          stroke-width=\"2\"\n          d=\"M21 21l-6-6m2-5a7 7 0 11-14 0 7 7 0 0114 0z\"\n        />\n      </svg>\n      <input\n        type=\"text\"\n        bind:value={searchQuery}\n        placeholder=\"Search conversations...\"\n        class=\"w-full bg-exo-black/40 border border-exo-medium-gray/30 rounded px-3 py-2 pl-9 text-xs text-white/90 placeholder:text-white/40 focus:outline-none focus:border-exo-yellow/30\"\n      />\n    </div>\n  </div>\n\n  <!-- Conversation List -->\n  <div class=\"flex-1 overflow-y-auto\">\n    {#if filteredConversations.length > 0}\n      <div class=\"py-2\">\n        <div class=\"px-4 py-2\">\n          <span\n            class=\"text-xs text-exo-light-gray font-mono tracking-wider uppercase\"\n          >\n            {searchQuery ? 
\"SEARCH RESULTS\" : \"CONVERSATIONS\"}\n          </span>\n        </div>\n\n        {#each filteredConversations as conversation (conversation.id)}\n          {@const info = resolveConversationInfo(conversation)}\n          <div class=\"px-2\">\n            {#if editingId === conversation.id}\n              <!-- Edit mode -->\n              <div\n                class=\"p-2 bg-transparent border border-exo-yellow/20 rounded mb-1\"\n              >\n                <input\n                  type=\"text\"\n                  bind:value={editingName}\n                  onkeydown={handleEditKeydown}\n                  class=\"w-full bg-exo-black/60 border border-exo-yellow/30 rounded px-2 py-1.5 text-xs text-exo-light-gray focus:outline-none focus:border-exo-yellow/50 mb-2\"\n                  autofocus\n                />\n                <div class=\"flex gap-2\">\n                  <button\n                    onclick={handleSaveEdit}\n                    class=\"flex-1 py-1.5 text-xs font-mono tracking-wider uppercase bg-transparent text-exo-yellow border border-exo-yellow/30 rounded hover:border-exo-yellow/50 cursor-pointer\"\n                  >\n                    SAVE\n                  </button>\n                  <button\n                    onclick={handleCancelEdit}\n                    class=\"flex-1 py-1.5 text-xs font-mono tracking-wider uppercase bg-exo-medium-gray/20 text-exo-light-gray border border-exo-medium-gray/30 rounded hover:bg-exo-medium-gray/30 cursor-pointer\"\n                  >\n                    CANCEL\n                  </button>\n                </div>\n              </div>\n            {:else if deleteConfirmId === conversation.id}\n              <!-- Delete confirmation -->\n              <div\n                class=\"p-2 bg-red-500/10 border border-red-500/30 rounded mb-1\"\n              >\n                <p class=\"text-xs text-red-400 mb-2\">\n                  Delete \"{conversation.name}\"?\n                </p>\n                <div class=\"flex gap-2\">\n                  <button\n                    onclick={handleConfirmDelete}\n                    class=\"flex-1 py-1.5 text-xs font-mono tracking-wider uppercase bg-red-500/20 text-red-400 border border-red-500/30 rounded hover:bg-red-500/30 cursor-pointer\"\n                  >\n                    DELETE\n                  </button>\n                  <button\n                    onclick={handleCancelDelete}\n                    class=\"flex-1 py-1.5 text-xs font-mono tracking-wider uppercase bg-exo-medium-gray/20 text-exo-light-gray border border-exo-medium-gray/30 rounded hover:bg-exo-medium-gray/30 cursor-pointer\"\n                  >\n                    CANCEL\n                  </button>\n                </div>\n              </div>\n            {:else}\n              <!-- Normal view -->\n              {@const stats = getLastAssistantStats(conversation)}\n              <div\n                role=\"button\"\n                tabindex=\"0\"\n                onclick={() => handleSelectConversation(conversation.id)}\n                onkeydown={(e) =>\n                  e.key === \"Enter\" &&\n                  handleSelectConversation(conversation.id)}\n                class=\"group w-full flex items-center justify-between p-2.5 rounded-lg mb-1 transition-all text-left cursor-pointer\n\t\t\t\t\t\t\t\t\t{activeId === conversation.id\n                  ? 
'bg-exo-yellow/5 border border-exo-yellow/30'\n                  : 'hover:bg-white/[0.03] hover:border-white/10 border border-transparent'}\"\n              >\n                <div class=\"flex-1 min-w-0 pr-2\">\n                  <div\n                    class=\"text-sm font-medium truncate {activeId ===\n                    conversation.id\n                      ? 'text-exo-yellow'\n                      : 'text-white'}\"\n                  >\n                    {conversation.name}\n                  </div>\n                  <div class=\"text-xs text-white/60 mt-0.5\">\n                    {formatDate(conversation.updatedAt)}\n                  </div>\n                  <div class=\"text-xs text-exo-light-gray truncate\">\n                    {info.modelLabel}\n                  </div>\n                  {#if stats}\n                    <div class=\"text-xs text-white/70 font-mono mt-1\">\n                      {#if stats.ttftMs}<span class=\"text-white/50\">TTFT</span>\n                        <span class=\"text-exo-yellow/80\"\n                          >{stats.ttftMs.toFixed(0)}ms</span\n                        >{/if}{#if stats.ttftMs && stats.tps}<span\n                          class=\"text-white/30 mx-1.5\">·</span\n                        >{/if}{#if stats.tps}<span class=\"text-exo-yellow/80\"\n                          >{stats.tps.toFixed(1)}</span\n                        >\n                        <span class=\"text-white/50\">tok/s</span>{/if}\n                    </div>\n                  {/if}\n                </div>\n\n                <div\n                  class=\"flex items-center gap-1 opacity-0 group-hover:opacity-100 transition-opacity\"\n                >\n                  <button\n                    type=\"button\"\n                    onclick={(e) =>\n                      handleStartEdit(conversation.id, conversation.name, e)}\n                    class=\"p-1 text-exo-light-gray hover:text-exo-yellow transition-colors cursor-pointer\"\n                    title=\"Rename\"\n                  >\n                    <svg\n                      class=\"w-3 h-3\"\n                      fill=\"none\"\n                      viewBox=\"0 0 24 24\"\n                      stroke=\"currentColor\"\n                    >\n                      <path\n                        stroke-linecap=\"round\"\n                        stroke-linejoin=\"round\"\n                        stroke-width=\"2\"\n                        d=\"M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z\"\n                      />\n                    </svg>\n                  </button>\n                  <button\n                    type=\"button\"\n                    onclick={(e) => handleDeleteClick(conversation.id, e)}\n                    class=\"p-1 text-exo-light-gray hover:text-red-400 transition-colors cursor-pointer\"\n                    title=\"Delete\"\n                  >\n                    <svg\n                      class=\"w-3 h-3\"\n                      fill=\"none\"\n                      viewBox=\"0 0 24 24\"\n                      stroke=\"currentColor\"\n                    >\n                      <path\n                        stroke-linecap=\"round\"\n                        stroke-linejoin=\"round\"\n                        stroke-width=\"2\"\n                        d=\"M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16\"\n              
        />\n                    </svg>\n                  </button>\n                </div>\n              </div>\n            {/if}\n          </div>\n        {/each}\n      </div>\n    {:else}\n      <div\n        class=\"flex flex-col items-center justify-center h-full p-4 text-center\"\n      >\n        <div\n          class=\"w-12 h-12 border border-exo-yellow/20 rounded-full flex items-center justify-center mb-3\"\n        >\n          <svg\n            class=\"w-6 h-6 text-exo-yellow/40\"\n            fill=\"none\"\n            viewBox=\"0 0 24 24\"\n            stroke=\"currentColor\"\n          >\n            <path\n              stroke-linecap=\"round\"\n              stroke-linejoin=\"round\"\n              stroke-width=\"1.5\"\n              d=\"M8 12h.01M12 12h.01M16 12h.01M21 12c0 4.418-4.03 8-9 8a9.863 9.863 0 01-4.255-.949L3 20l1.395-3.72C3.512 15.042 3 13.574 3 12c0-4.418 4.03-8 9-8s9 3.582 9 8z\"\n            />\n          </svg>\n        </div>\n        <p\n          class=\"text-xs text-white/70 font-mono tracking-wider uppercase mb-1\"\n        >\n          {searchQuery ? \"NO RESULTS\" : \"NO CONVERSATIONS\"}\n        </p>\n        <p class=\"text-sm text-white/50\">\n          {searchQuery ? \"Try a different search\" : \"Start a new chat to begin\"}\n        </p>\n      </div>\n    {/if}\n  </div>\n\n  <!-- Footer -->\n  <div class=\"p-3 border-t border-exo-yellow/10\">\n    {#if showDeleteAllConfirm}\n      <div class=\"bg-red-500/10 border border-red-500/30 rounded p-2 mb-2\">\n        <p class=\"text-xs text-red-400 text-center mb-2\">\n          Delete all {conversationList.length} conversations?\n        </p>\n        <div class=\"flex gap-2\">\n          <button\n            onclick={handleConfirmDeleteAll}\n            class=\"flex-1 py-1.5 text-xs font-mono tracking-wider uppercase bg-red-500/20 text-red-400 border border-red-500/30 rounded hover:bg-red-500/30 transition-colors cursor-pointer\"\n          >\n            DELETE ALL\n          </button>\n          <button\n            onclick={handleCancelDeleteAll}\n            class=\"flex-1 py-1.5 text-xs font-mono tracking-wider uppercase bg-exo-medium-gray/20 text-exo-light-gray border border-exo-medium-gray/30 rounded hover:bg-exo-medium-gray/30 transition-colors cursor-pointer\"\n          >\n            CANCEL\n          </button>\n        </div>\n      </div>\n    {:else if conversationList.length > 0}\n      <button\n        onclick={handleDeleteAllClick}\n        class=\"w-full flex items-center justify-center gap-2 py-1.5 text-sm font-mono tracking-wider uppercase text-white/70 hover:text-red-400 hover:bg-red-500/10 border border-transparent hover:border-red-500/20 rounded transition-all cursor-pointer\"\n      >\n        <svg\n          class=\"w-3.5 h-3.5\"\n          fill=\"none\"\n          viewBox=\"0 0 24 24\"\n          stroke=\"currentColor\"\n        >\n          <path\n            stroke-linecap=\"round\"\n            stroke-linejoin=\"round\"\n            stroke-width=\"2\"\n            d=\"M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16\"\n          />\n        </svg>\n        DELETE ALL CHATS\n      </button>\n    {/if}\n    <div\n      class=\"flex items-center justify-center gap-3 {conversationList.length >\n        0 && !showDeleteAllConfirm\n        ? 
'mt-2'\n        : ''}\"\n    >\n      <button\n        type=\"button\"\n        onclick={toggleDebugMode}\n        class=\"p-1.5 rounded border border-exo-medium-gray/40 hover:border-exo-yellow/50 transition-colors cursor-pointer\"\n        title=\"Toggle debug mode\"\n      >\n        <svg\n          class=\"w-4 h-4 {debugEnabled\n            ? 'text-exo-yellow'\n            : 'text-exo-medium-gray'}\"\n          fill=\"currentColor\"\n          viewBox=\"0 0 24 24\"\n        >\n          <path\n            d=\"M19 8h-1.81A6.002 6.002 0 0 0 12 2a6.002 6.002 0 0 0-5.19 3H5a1 1 0 0 0 0 2h1v2H5a1 1 0 0 0 0 2h1v2H5a1 1 0 0 0 0 2h1.81A6.002 6.002 0 0 0 12 22a6.002 6.002 0 0 0 5.19-3H19a1 1 0 0 0 0-2h-1v-2h1a1 1 0 0 0 0-2h-1v-2h1a1 1 0 1 0 0-2Zm-5 10.32V19a1 1 0 1 1-2 0v-.68a3.999 3.999 0 0 1-3-3.83V9.32a3.999 3.999 0 0 1 3-3.83V5a1 1 0 0 1 2 0v.49a3.999 3.999 0 0 1 3 3.83v5.17a3.999 3.999 0 0 1-3 3.83Z\"\n          />\n        </svg>\n      </button>\n      <div class=\"text-xs text-white/60 font-mono tracking-wider text-center\">\n        {conversationList.length} CONVERSATION{conversationList.length !== 1\n          ? \"S\"\n          : \"\"}\n      </div>\n      <button\n        type=\"button\"\n        onclick={toggleTopologyOnlyMode}\n        class=\"p-1.5 rounded border border-exo-medium-gray/40 hover:border-exo-yellow/50 transition-colors cursor-pointer\"\n        title=\"Toggle topology only mode\"\n      >\n        <svg\n          class=\"w-4 h-4 {topologyOnlyEnabled\n            ? 'text-exo-yellow'\n            : 'text-exo-medium-gray'}\"\n          fill=\"none\"\n          viewBox=\"0 0 24 24\"\n          stroke=\"currentColor\"\n          stroke-width=\"2\"\n        >\n          <circle cx=\"12\" cy=\"5\" r=\"2\" fill=\"currentColor\" />\n          <circle cx=\"5\" cy=\"19\" r=\"2\" fill=\"currentColor\" />\n          <circle cx=\"19\" cy=\"19\" r=\"2\" fill=\"currentColor\" />\n          <path stroke-linecap=\"round\" d=\"M12 7v5m0 0l-5 5m5-5l5 5\" />\n        </svg>\n      </button>\n    </div>\n  </div>\n{/snippet}\n\n{#if isMobileDrawer}\n  <!-- Mobile drawer with overlay -->\n  {#if isOpen}\n    <!-- Overlay backdrop -->\n    <button\n      type=\"button\"\n      class=\"fixed inset-0 bg-black/60 backdrop-blur-sm z-40 md:hidden\"\n      onclick={() => onClose?.()}\n      aria-label=\"Close sidebar\"\n    ></button>\n    <!-- Drawer panel -->\n    <aside\n      class=\"fixed left-0 top-0 bottom-0 w-72 bg-exo-dark-gray border-r border-exo-yellow/10 z-50 flex flex-col md:hidden\"\n    >\n      {@render sidebarContent()}\n    </aside>\n  {/if}\n{:else}\n  <!-- Desktop sidebar -->\n  <aside\n    class=\"flex flex-col h-full bg-exo-dark-gray border-r border-exo-yellow/10 {className}\"\n  >\n    {@render sidebarContent()}\n  </aside>\n{/if}\n"
  },
  {
    "path": "dashboard/src/lib/components/ConnectionBanner.svelte",
    "content": "<script lang=\"ts\">\n  import { isConnected } from \"$lib/stores/app.svelte\";\n  import { slide } from \"svelte/transition\";\n\n  const connected = $derived(isConnected());\n</script>\n\n{#if !connected}\n  <div\n    transition:slide={{ duration: 200 }}\n    class=\"relative z-50 flex items-center justify-center gap-2 px-4 py-2 bg-red-950/80 border-b border-red-500/30\"\n    role=\"alert\"\n    aria-live=\"assertive\"\n  >\n    <div class=\"w-2 h-2 bg-red-500 rounded-full animate-pulse\"></div>\n    <span class=\"text-xs font-mono text-red-300 tracking-wider uppercase\">\n      Connection lost &mdash; Reconnecting to backend&hellip;\n    </span>\n  </div>\n{/if}\n"
  },
  {
    "path": "dashboard/src/lib/components/DeviceIcon.svelte",
    "content": "<script lang=\"ts\">\n  /**\n   * DeviceIcon — renders a device icon as an SVG <g> element.\n   * Uses the exact same proportional math as TopologyGraph.svelte\n   * so that devices look identical in both the topology view and\n   * the onboarding animation.\n   *\n   * Must be placed inside an <svg> element.\n   */\n\n  interface Props {\n    /** \"macbook pro\" | \"mac studio\" | \"mac mini\" etc. */\n    deviceType: string;\n    /** Center X coordinate in SVG space */\n    cx: number;\n    /** Center Y coordinate in SVG space */\n    cy: number;\n    /** Base sizing factor (equivalent to TopologyGraph's nodeRadius) */\n    size?: number;\n    /** RAM usage 0–100 */\n    ramPercent?: number;\n    /** Unique id suffix for clip-path ids */\n    uid?: string;\n  }\n\n  let {\n    deviceType,\n    cx,\n    cy,\n    size = 60,\n    ramPercent = 60,\n    uid = \"dev\",\n  }: Props = $props();\n\n  // Apple logo path — same constant used by TopologyGraph\n  const APPLE_LOGO_PATH =\n    \"M788.1 340.9c-5.8 4.5-108.2 62.2-108.2 190.5 0 148.4 130.3 200.9 134.2 202.2-.6 3.2-20.7 71.9-68.7 141.9-42.8 61.6-87.5 123.1-155.5 123.1s-85.5-39.5-164-39.5c-76.5 0-103.7 40.8-165.9 40.8s-105.6-57-155.5-127C46.7 790.7 0 663 0 541.8c0-194.4 126.4-297.5 250.8-297.5 66.1 0 121.2 43.4 162.7 43.4 39.5 0 101.1-46 176.3-46 28.5 0 130.9 2.6 198.3 99.2zm-234-181.5c31.1-36.9 53.1-88.1 53.1-139.3 0-7.1-.6-14.3-1.9-20.1-50.6 1.9-110.8 33.7-147.1 75.8-28.5 32.4-55.1 83.6-55.1 135.5 0 7.8 1.3 15.6 1.9 18.1 3.2.6 8.4 1.3 13.6 1.3 45.4 0 102.5-30.4 135.5-71.3z\";\n  const LOGO_NATIVE_WIDTH = 814;\n  const LOGO_NATIVE_HEIGHT = 1000;\n\n  const wireColor = \"rgba(179,179,179,0.8)\";\n  const strokeWidth = 1.5;\n\n  const modelLower = $derived(deviceType.toLowerCase());\n\n  // ── Mac Studio dimensions (same ratios as TopologyGraph) ──\n  const studioW = $derived(size * 1.25);\n  const studioH = $derived(size * 0.85);\n  const studioX = $derived(cx - studioW / 2);\n  const studioY = $derived(cy - studioH / 2);\n  const studioCorner = 4;\n  const studioTopH = $derived(studioH * 0.15);\n\n  // Studio front panel details\n  const studioSlotH = $derived(studioH * 0.14);\n  const studioVSlotW = $derived(studioW * 0.05);\n  const studioVSlotY = $derived(\n    studioY + studioTopH + (studioH - studioTopH) * 0.6,\n  );\n  const studioVSlot1X = $derived(studioX + studioW * 0.18);\n  const studioVSlot2X = $derived(studioX + studioW * 0.28);\n  const studioHSlotW = $derived(studioW * 0.2);\n  const studioHSlotX = $derived(studioX + studioW * 0.5 - studioHSlotW / 2);\n\n  // Studio memory fill\n  const studioMemTotalH = $derived(studioH - studioTopH);\n  const studioMemH = $derived((ramPercent / 100) * studioMemTotalH);\n\n  // ── MacBook dimensions (same ratios as TopologyGraph) ──\n  const mbW = $derived((size * 1.6 * 0.85) / 1.15);\n  const mbH = $derived(size * 0.85);\n  const mbX = $derived(cx - mbW / 2);\n  const mbY = $derived(cy - mbH / 2);\n\n  const mbScreenH = $derived(mbH * 0.7);\n  const mbBaseH = $derived(mbH * 0.3);\n  const mbScreenW = $derived(mbW * 0.85);\n  const mbScreenX = $derived(cx - mbScreenW / 2);\n  const mbBezel = 3;\n\n  // MacBook memory fill\n  const mbMemTotalH = $derived(mbScreenH - mbBezel * 2);\n  const mbMemH = $derived((ramPercent / 100) * mbMemTotalH);\n\n  // Apple logo sizing\n  const mbLogoTargetH = $derived(mbScreenH * 0.22);\n  const mbLogoScale = $derived(mbLogoTargetH / LOGO_NATIVE_HEIGHT);\n  const mbLogoX = $derived(cx - (LOGO_NATIVE_WIDTH * mbLogoScale) / 2);\n  const mbLogoY = 
$derived(\n    mbY + mbScreenH / 2 - (LOGO_NATIVE_HEIGHT * mbLogoScale) / 2,\n  );\n\n  // MacBook base (trapezoidal)\n  const mbBaseY = $derived(mbY + mbScreenH);\n  const mbBaseTopW = $derived(mbScreenW);\n  const mbBaseBottomW = $derived(mbW);\n  const mbBaseTopX = $derived(cx - mbBaseTopW / 2);\n  const mbBaseBottomX = $derived(cx - mbBaseBottomW / 2);\n\n  // Keyboard\n  const mbKbX = $derived(mbBaseTopX + 6);\n  const mbKbY = $derived(mbBaseY + 3);\n  const mbKbW = $derived(mbBaseTopW - 12);\n  const mbKbH = $derived(mbBaseH * 0.55);\n\n  // Trackpad\n  const mbTpW = $derived(mbBaseTopW * 0.4);\n  const mbTpX = $derived(cx - mbTpW / 2);\n  const mbTpY = $derived(mbBaseY + mbKbH + 5);\n  const mbTpH = $derived(mbBaseH * 0.3);\n\n  // Clip IDs\n  const screenClipId = $derived(`di-screen-${uid}`);\n  const studioClipId = $derived(`di-studio-${uid}`);\n</script>\n\n{#if modelLower === \"mac studio\" || modelLower === \"mac mini\"}\n  <!-- Mac Studio / Mac Mini -->\n  <defs>\n    <clipPath id={studioClipId}>\n      <rect\n        x={studioX}\n        y={studioY + studioTopH}\n        width={studioW}\n        height={studioH - studioTopH}\n        rx={studioCorner - 1}\n      />\n    </clipPath>\n  </defs>\n\n  <!-- Main body -->\n  <rect\n    x={studioX}\n    y={studioY}\n    width={studioW}\n    height={studioH}\n    rx={studioCorner}\n    fill=\"#1a1a1a\"\n    stroke={wireColor}\n    stroke-width={strokeWidth}\n  />\n\n  <!-- Memory fill -->\n  {#if ramPercent > 0}\n    <rect\n      x={studioX}\n      y={studioY + studioTopH + (studioMemTotalH - studioMemH)}\n      width={studioW}\n      height={studioMemH}\n      fill=\"rgba(255,215,0,0.75)\"\n      clip-path=\"url(#{studioClipId})\"\n    />\n  {/if}\n\n  <!-- Top surface divider -->\n  <line\n    x1={studioX}\n    y1={studioY + studioTopH}\n    x2={studioX + studioW}\n    y2={studioY + studioTopH}\n    stroke=\"rgba(179,179,179,0.3)\"\n    stroke-width=\"0.5\"\n  />\n\n  <!-- Front panel: vertical slots -->\n  <rect\n    x={studioVSlot1X - studioVSlotW / 2}\n    y={studioVSlotY}\n    width={studioVSlotW}\n    height={studioSlotH}\n    fill=\"rgba(0,0,0,0.35)\"\n    rx=\"1.5\"\n  />\n  <rect\n    x={studioVSlot2X - studioVSlotW / 2}\n    y={studioVSlotY}\n    width={studioVSlotW}\n    height={studioSlotH}\n    fill=\"rgba(0,0,0,0.35)\"\n    rx=\"1.5\"\n  />\n\n  <!-- Horizontal slot (SD card) -->\n  <rect\n    x={studioHSlotX}\n    y={studioVSlotY}\n    width={studioHSlotW}\n    height={studioSlotH * 0.6}\n    fill=\"rgba(0,0,0,0.35)\"\n    rx=\"1\"\n  />\n{:else}\n  <!-- MacBook Pro -->\n  <defs>\n    <clipPath id={screenClipId}>\n      <rect\n        x={mbScreenX + mbBezel}\n        y={mbY + mbBezel}\n        width={mbScreenW - mbBezel * 2}\n        height={mbScreenH - mbBezel * 2}\n        rx=\"2\"\n      />\n    </clipPath>\n  </defs>\n\n  <!-- Screen outer frame -->\n  <rect\n    x={mbScreenX}\n    y={mbY}\n    width={mbScreenW}\n    height={mbScreenH}\n    rx=\"3\"\n    fill=\"#1a1a1a\"\n    stroke={wireColor}\n    stroke-width={strokeWidth}\n  />\n\n  <!-- Screen inner (dark) -->\n  <rect\n    x={mbScreenX + mbBezel}\n    y={mbY + mbBezel}\n    width={mbScreenW - mbBezel * 2}\n    height={mbScreenH - mbBezel * 2}\n    rx=\"2\"\n    fill=\"#0a0a12\"\n  />\n\n  <!-- Memory fill on screen -->\n  {#if ramPercent > 0}\n    <rect\n      x={mbScreenX + mbBezel}\n      y={mbY + mbBezel + (mbMemTotalH - mbMemH)}\n      width={mbScreenW - mbBezel * 2}\n      height={mbMemH}\n      fill=\"rgba(255,215,0,0.85)\"\n      
clip-path=\"url(#{screenClipId})\"\n    />\n  {/if}\n\n  <!-- Apple logo -->\n  <path\n    d={APPLE_LOGO_PATH}\n    transform=\"translate({mbLogoX}, {mbLogoY}) scale({mbLogoScale})\"\n    fill=\"#FFFFFF\"\n    opacity=\"0.9\"\n  />\n\n  <!-- Keyboard base (trapezoidal) -->\n  <path\n    d=\"M {mbBaseTopX} {mbBaseY} L {mbBaseTopX +\n      mbBaseTopW} {mbBaseY} L {mbBaseBottomX + mbBaseBottomW} {mbBaseY +\n      mbBaseH} L {mbBaseBottomX} {mbBaseY + mbBaseH} Z\"\n    fill=\"#2c2c2c\"\n    stroke={wireColor}\n    stroke-width=\"1\"\n  />\n\n  <!-- Keyboard area -->\n  <rect\n    x={mbKbX}\n    y={mbKbY}\n    width={mbKbW}\n    height={mbKbH}\n    fill=\"rgba(0,0,0,0.2)\"\n    rx=\"2\"\n  />\n\n  <!-- Trackpad -->\n  <rect\n    x={mbTpX}\n    y={mbTpY}\n    width={mbTpW}\n    height={mbTpH}\n    fill=\"rgba(255,255,255,0.08)\"\n    rx=\"2\"\n  />\n{/if}\n"
  },
  {
    "path": "dashboard/src/lib/components/FamilyLogos.svelte",
    "content": "<script lang=\"ts\">\n  type FamilyLogoProps = {\n    family: string;\n    class?: string;\n  };\n\n  let { family, class: className = \"\" }: FamilyLogoProps = $props();\n</script>\n\n{#if family === \"favorites\"}\n  <svg class=\"w-6 h-6 {className}\" viewBox=\"0 0 24 24\" fill=\"currentColor\">\n    <path\n      d=\"M12 2l3.09 6.26L22 9.27l-5 4.87 1.18 6.88L12 17.77l-6.18 3.25L7 14.14 2 9.27l6.91-1.01L12 2z\"\n    />\n  </svg>\n{:else if family === \"recents\"}\n  <svg class=\"w-6 h-6 {className}\" viewBox=\"0 0 24 24\" fill=\"currentColor\">\n    <path\n      d=\"M13 3a9 9 0 0 0-9 9H1l3.89 3.89.07.14L9 12H6c0-3.87 3.13-7 7-7s7 3.13 7 7-3.13 7-7 7c-1.93 0-3.68-.79-4.94-2.06l-1.42 1.42A8.954 8.954 0 0 0 13 21a9 9 0 0 0 0-18zm-1 5v5l4.28 2.54.72-1.21-3.5-2.08V8H12z\"\n    />\n  </svg>\n{:else if family === \"llama\" || family === \"meta\"}\n  <svg class=\"w-6 h-6 {className}\" viewBox=\"0 0 24 24\" fill=\"currentColor\">\n    <path\n      d=\"M6.915 4.03c-1.968 0-3.683 1.28-4.871 3.113C.704 9.208 0 11.883 0 14.449c0 .706.07 1.369.21 1.973a6.624 6.624 0 0 0 .265.86 5.297 5.297 0 0 0 .371.761c.696 1.159 1.818 1.927 3.593 1.927 1.497 0 2.633-.671 3.965-2.444.76-1.012 1.144-1.626 2.663-4.32l.756-1.339.186-.325c.061.1.121.196.183.3l2.152 3.595c.724 1.21 1.665 2.556 2.47 3.314 1.046.987 1.992 1.22 3.06 1.22 1.075 0 1.876-.355 2.455-.843a3.743 3.743 0 0 0 .81-.973c.542-.939.861-2.127.861-3.745 0-2.72-.681-5.357-2.084-7.45-1.282-1.912-2.957-2.93-4.716-2.93-1.047 0-2.088.467-3.053 1.308-.652.57-1.257 1.29-1.82 2.05-.69-.875-1.335-1.547-1.958-2.056-1.182-.966-2.315-1.303-3.454-1.303zm10.16 2.053c1.147 0 2.188.758 2.992 1.999 1.132 1.748 1.647 4.195 1.647 6.4 0 1.548-.368 2.9-1.839 2.9-.58 0-1.027-.23-1.664-1.004-.496-.601-1.343-1.878-2.832-4.358l-.617-1.028a44.908 44.908 0 0 0-1.255-1.98c.07-.109.141-.224.211-.327 1.12-1.667 2.118-2.602 3.358-2.602zm-10.201.553c1.265 0 2.058.791 2.675 1.446.307.327.737.871 1.234 1.579l-1.02 1.566c-.757 1.163-1.882 3.017-2.837 4.338-1.191 1.649-1.81 1.817-2.486 1.817-.524 0-1.038-.237-1.383-.794-.263-.426-.464-1.13-.464-2.046 0-2.221.63-4.535 1.66-6.088.454-.687.964-1.226 1.533-1.533a2.264 2.264 0 0 1 1.088-.285z\"\n    />\n  </svg>\n{:else if family === \"qwen\"}\n  <svg class=\"w-6 h-6 {className}\" viewBox=\"0 0 24 24\" fill=\"currentColor\">\n    <path\n      d=\"M12.604 1.34c.393.69.784 1.382 1.174 2.075a.18.18 0 00.157.091h5.552c.174 0 .322.11.446.327l1.454 2.57c.19.337.24.478.024.837-.26.43-.513.864-.76 1.3l-.367.658c-.106.196-.223.28-.04.512l2.652 4.637c.172.301.111.494-.043.77-.437.785-.882 1.564-1.335 2.34-.159.272-.352.375-.68.37-.777-.016-1.552-.01-2.327.016a.099.099 0 00-.081.05 575.097 575.097 0 01-2.705 4.74c-.169.293-.38.363-.725.364-.997.003-2.002.004-3.017.002a.537.537 0 01-.465-.271l-1.335-2.323a.09.09 0 00-.083-.049H4.982c-.285.03-.553-.001-.805-.092l-1.603-2.77a.543.543 0 01-.002-.54l1.207-2.12a.198.198 0 000-.197 550.951 550.951 0 01-1.875-3.272l-.79-1.395c-.16-.31-.173-.496.095-.965.465-.813.927-1.625 1.387-2.436.132-.234.304-.334.584-.335a338.3 338.3 0 012.589-.001.124.124 0 00.107-.063l2.806-4.895a.488.488 0 01.422-.246c.524-.001 1.053 0 1.583-.006L11.704 1c.341-.003.724.032.9.34zm-3.432.403a.06.06 0 00-.052.03L6.254 6.788a.157.157 0 01-.135.078H3.253c-.056 0-.07.025-.041.074l5.81 10.156c.025.042.013.062-.034.063l-2.795.015a.218.218 0 00-.2.116l-1.32 2.31c-.044.078-.021.118.068.118l5.716.008c.046 0 .08.02.104.061l1.403 2.454c.046.081.092.082.139 0l5.006-8.76.783-1.382a.055.055 0 01.096 0l1.424 2.53a.122.122 0 
00.107.062l2.763-.02a.04.04 0 00.035-.02.041.041 0 000-.04l-2.9-5.086a.108.108 0 010-.113l.293-.507 1.12-1.977c.024-.041.012-.062-.035-.062H9.2c-.059 0-.073-.026-.043-.077l1.434-2.505a.107.107 0 000-.114L9.225 1.774a.06.06 0 00-.053-.031zm6.29 8.02c.046 0 .058.02.034.06l-.832 1.465-2.613 4.585a.056.056 0 01-.05.029.058.058 0 01-.05-.029L8.498 9.841c-.02-.034-.01-.052.028-.054l.216-.012 6.722-.012z\"\n    />\n  </svg>\n{:else if family === \"deepseek\"}\n  <svg class=\"w-6 h-6 {className}\" viewBox=\"0 0 24 24\" fill=\"currentColor\">\n    <path\n      d=\"M23.748 4.482c-.254-.124-.364.113-.512.234-.051.039-.094.09-.137.136-.372.397-.806.657-1.373.626-.829-.046-1.537.214-2.163.848-.133-.782-.575-1.248-1.247-1.548-.352-.156-.708-.311-.955-.65-.172-.241-.219-.51-.305-.774-.055-.16-.11-.323-.293-.35-.2-.031-.278.136-.356.276-.313.572-.434 1.202-.422 1.84.027 1.436.633 2.58 1.838 3.393.137.093.172.187.129.323-.082.28-.18.552-.266.833-.055.179-.137.217-.329.14a5.526 5.526 0 01-1.736-1.18c-.857-.828-1.631-1.742-2.597-2.458a11.365 11.365 0 00-.689-.471c-.985-.957.13-1.743.388-1.836.27-.098.093-.432-.779-.428-.872.004-1.67.295-2.687.684a3.055 3.055 0 01-.465.137 9.597 9.597 0 00-2.883-.102c-1.885.21-3.39 1.102-4.497 2.623C.082 8.606-.231 10.684.152 12.85c.403 2.284 1.569 4.175 3.36 5.653 1.858 1.533 3.997 2.284 6.438 2.14 1.482-.085 3.133-.284 4.994-1.86.47.234.962.327 1.78.397.63.059 1.236-.03 1.705-.128.735-.156.684-.837.419-.961-2.155-1.004-1.682-.595-2.113-.926 1.096-1.296 2.746-2.642 3.392-7.003.05-.347.007-.565 0-.845-.004-.17.035-.237.23-.256a4.173 4.173 0 001.545-.475c1.396-.763 1.96-2.015 2.093-3.517.02-.23-.004-.467-.247-.588zM11.581 18c-2.089-1.642-3.102-2.183-3.52-2.16-.392.024-.321.471-.235.763.09.288.207.486.371.739.114.167.192.416-.113.603-.673.416-1.842-.14-1.897-.167-1.361-.802-2.5-1.86-3.301-3.307-.774-1.393-1.224-2.887-1.298-4.482-.02-.386.093-.522.477-.592a4.696 4.696 0 011.529-.039c2.132.312 3.946 1.265 5.468 2.774.868.86 1.525 1.887 2.202 2.891.72 1.066 1.494 2.082 2.48 2.914.348.292.625.514.891.677-.802.09-2.14.11-3.054-.614zm1-6.44a.306.306 0 01.415-.287.302.302 0 01.2.288.306.306 0 01-.31.307.303.303 0 01-.304-.308zm3.11 1.596c-.2.081-.399.151-.59.16a1.245 1.245 0 01-.798-.254c-.274-.23-.47-.358-.552-.758a1.73 1.73 0 01.016-.588c.07-.327-.008-.537-.239-.727-.187-.156-.426-.199-.688-.199a.559.559 0 01-.254-.078c-.11-.054-.2-.19-.114-.358.028-.054.16-.186.192-.21.356-.202.767-.136 1.146.016.352.144.618.408 1.001.782.391.451.462.576.685.914.176.265.336.537.445.848.067.195-.019.354-.25.452z\"\n    />\n  </svg>\n{:else if family === \"openai\" || family === \"gpt-oss\"}\n  <svg class=\"w-6 h-6 {className}\" viewBox=\"0 0 24 24\" fill=\"currentColor\">\n    <path\n      d=\"M22.2819 9.8211a5.9847 5.9847 0 0 0-.5157-4.9108 6.0462 6.0462 0 0 0-6.5098-2.9A6.0651 6.0651 0 0 0 4.9807 4.1818a5.9847 5.9847 0 0 0-3.9977 2.9 6.0462 6.0462 0 0 0 .7427 7.0966 5.98 5.98 0 0 0 .511 4.9107 6.051 6.051 0 0 0 6.5146 2.9001A5.9847 5.9847 0 0 0 13.2599 24a6.0557 6.0557 0 0 0 5.7718-4.2058 5.9894 5.9894 0 0 0 3.9977-2.9001 6.0557 6.0557 0 0 0-.7475-7.0729zm-9.022 12.6081a4.4755 4.4755 0 0 1-2.8764-1.0408l.1419-.0804 4.7783-2.7582a.7948.7948 0 0 0 .3927-.6813v-6.7369l2.02 1.1686a.071.071 0 0 1 .038.052v5.5826a4.504 4.504 0 0 1-4.4945 4.4944zm-9.6607-4.1254a4.4708 4.4708 0 0 1-.5346-3.0137l.142.0852 4.783 2.7582a.7712.7712 0 0 0 .7806 0l5.8428-3.3685v2.3324a.0804.0804 0 0 1-.0332.0615L9.74 19.9502a4.4992 4.4992 0 0 1-6.1408-1.6464zM2.3408 7.8956a4.485 4.485 0 0 1 2.3655-1.9728V11.6a.7664.7664 0 0 0 
.3879.6765l5.8144 3.3543-2.0201 1.1685a.0757.0757 0 0 1-.071 0l-4.8303-2.7865A4.504 4.504 0 0 1 2.3408 7.872zm16.5963 3.8558L13.1038 8.364 15.1192 7.2a.0757.0757 0 0 1 .071 0l4.8303 2.7913a4.4944 4.4944 0 0 1-.6765 8.1042v-5.6772a.79.79 0 0 0-.407-.667zm2.0107-3.0231l-.142-.0852-4.7735-2.7818a.7759.7759 0 0 0-.7854 0L9.409 9.2297V6.8974a.0662.0662 0 0 1 .0284-.0615l4.8303-2.7866a4.4992 4.4992 0 0 1 6.6802 4.66zM8.3065 12.863l-2.02-1.1638a.0804.0804 0 0 1-.038-.0567V6.0742a4.4992 4.4992 0 0 1 7.3757-3.4537l-.142.0805L8.704 5.459a.7948.7948 0 0 0-.3927.6813zm1.0976-2.3654l2.602-1.4998 2.6069 1.4998v2.9994l-2.5974 1.4997-2.6067-1.4997Z\"\n    />\n  </svg>\n{:else if family === \"glm\"}\n  <svg class=\"w-6 h-6 {className}\" viewBox=\"0 0 24 24\" fill=\"currentColor\">\n    <path\n      d=\"M11.991 23.503a.24.24 0 00-.244.248.24.24 0 00.244.249.24.24 0 00.245-.249.24.24 0 00-.22-.247l-.025-.001zM9.671 5.365a1.697 1.697 0 011.099 2.132l-.071.172-.016.04-.018.054c-.07.16-.104.32-.104.498-.035.71.47 1.279 1.186 1.314h.366c1.309.053 2.338 1.173 2.286 2.523-.052 1.332-1.152 2.38-2.478 2.327h-.174c-.715.018-1.274.64-1.239 1.368 0 .124.018.23.053.337.209.373.54.658.96.8.75.23 1.517-.125 1.9-.782l.018-.035c.402-.64 1.17-.96 1.92-.711.854.284 1.378 1.226 1.099 2.167a1.661 1.661 0 01-2.077 1.102 1.711 1.711 0 01-.907-.711l-.017-.035c-.2-.323-.463-.58-.851-.711l-.056-.018a1.646 1.646 0 00-1.954.746 1.66 1.66 0 01-1.065.764 1.677 1.677 0 01-1.989-1.279c-.209-.906.332-1.83 1.257-2.043a1.51 1.51 0 01.296-.035h.018c.68-.071 1.151-.622 1.116-1.333a1.307 1.307 0 00-.227-.693 2.515 2.515 0 01-.366-1.403 2.39 2.39 0 01.366-1.208c.14-.195.21-.444.227-.693.018-.71-.506-1.261-1.186-1.332l-.07-.018a1.43 1.43 0 01-.299-.07l-.05-.019a1.7 1.7 0 01-1.047-2.114 1.68 1.68 0 012.094-1.101zm-5.575 10.11c.26-.264.639-.367.994-.27.355.096.633.379.728.74.095.362-.007.748-.267 1.013-.402.41-1.053.41-1.455 0a1.062 1.062 0 010-1.482zm14.845-.294c.359-.09.738.024.992.297.254.274.344.665.237 1.025-.107.36-.396.634-.756.718-.551.128-1.1-.22-1.23-.781a1.05 1.05 0 01.757-1.26zm-.064-4.39c.314.32.49.753.49 1.206 0 .452-.176.886-.49 1.206-.315.32-.74.5-1.185.5-.444 0-.87-.18-1.184-.5a1.727 1.727 0 010-2.412 1.654 1.654 0 012.369 0zm-11.243.163c.364.484.447 1.128.218 1.691a1.665 1.665 0 01-2.188.923c-.855-.36-1.26-1.358-.907-2.228a1.68 1.68 0 011.33-1.038c.593-.08 1.183.169 1.547.652zm11.545-4.221c.368 0 .708.2.892.524.184.324.184.724 0 1.048a1.026 1.026 0 01-.892.524c-.568 0-1.03-.47-1.03-1.048 0-.579.462-1.048 1.03-1.048zm-14.358 0c.368 0 .707.2.891.524.184.324.184.724 0 1.048a1.026 1.026 0 01-.891.524c-.569 0-1.03-.47-1.03-1.048 0-.579.461-1.048 1.03-1.048zm10.031-1.475c.925 0 1.675.764 1.675 1.706s-.75 1.705-1.675 1.705-1.674-.763-1.674-1.705c0-.942.75-1.706 1.674-1.706zm-2.626-.684c.362-.082.653-.356.761-.718a1.062 1.062 0 00-.238-1.028 1.017 1.017 0 00-.996-.294c-.547.14-.881.7-.752 1.257.13.558.675.907 1.225.783zm0 16.876c.359-.087.644-.36.75-.72a1.062 1.062 0 00-.237-1.019 1.018 1.018 0 00-.985-.301 1.037 1.037 0 00-.762.717c-.108.361-.017.754.239 1.028.245.263.606.377.953.305l.043-.01zM17.19 3.5a.631.631 0 00.628-.64c0-.355-.279-.64-.628-.64a.631.631 0 00-.628.64c0 .355.28.64.628.64zm-10.38 0a.631.631 0 00.628-.64c0-.355-.28-.64-.628-.64a.631.631 0 00-.628.64c0 .355.279.64.628.64zm-5.182 7.852a.631.631 0 00-.628.64c0 .354.28.639.628.639a.63.63 0 00.627-.606l.001-.034a.62.62 0 00-.628-.64zm5.182 9.13a.631.631 0 00-.628.64c0 .355.279.64.628.64a.631.631 0 00.628-.64c0-.355-.28-.64-.628-.64zm10.38.018a.631.631 0 00-.628.64c0 
.355.28.64.628.64a.631.631 0 00.628-.64c0-.355-.279-.64-.628-.64zm5.182-9.148a.631.631 0 00-.628.64c0 .354.279.639.628.639a.631.631 0 00.628-.64c0-.355-.28-.64-.628-.64zm-.384-4.992a.24.24 0 00.244-.249.24.24 0 00-.244-.249.24.24 0 00-.244.249c0 .142.122.249.244.249zM11.991.497a.24.24 0 00.245-.248A.24.24 0 0011.99 0a.24.24 0 00-.244.249c0 .133.108.236.223.247l.021.001zM2.011 6.36a.24.24 0 00.245-.249.24.24 0 00-.244-.249.24.24 0 00-.244.249.24.24 0 00.244.249zm0 11.263a.24.24 0 00-.243.248.24.24 0 00.244.249.24.24 0 00.244-.249.252.252 0 00-.244-.248zm19.995-.018a.24.24 0 00-.245.248.24.24 0 00.245.25.24.24 0 00.244-.25.252.252 0 00-.244-.248z\"\n    />\n  </svg>\n{:else if family === \"minimax\"}\n  <svg class=\"w-6 h-6 {className}\" viewBox=\"0 0 24 24\" fill=\"currentColor\">\n    <path\n      d=\"M16.278 2c1.156 0 2.093.927 2.093 2.07v12.501a.74.74 0 00.744.709.74.74 0 00.743-.709V9.099a2.06 2.06 0 012.071-2.049A2.06 2.06 0 0124 9.1v6.561a.649.649 0 01-.652.645.649.649 0 01-.653-.645V9.1a.762.762 0 00-.766-.758.762.762 0 00-.766.758v7.472a2.037 2.037 0 01-2.048 2.026 2.037 2.037 0 01-2.048-2.026v-12.5a.785.785 0 00-.788-.753.785.785 0 00-.789.752l-.001 15.904A2.037 2.037 0 0113.441 22a2.037 2.037 0 01-2.048-2.026V18.04c0-.356.292-.645.652-.645.36 0 .652.289.652.645v1.934c0 .263.142.506.372.638.23.131.514.131.744 0a.734.734 0 00.372-.638V4.07c0-1.143.937-2.07 2.093-2.07zm-5.674 0c1.156 0 2.093.927 2.093 2.07v11.523a.648.648 0 01-.652.645.648.648 0 01-.652-.645V4.07a.785.785 0 00-.789-.78.785.785 0 00-.789.78v14.013a2.06 2.06 0 01-2.07 2.048 2.06 2.06 0 01-2.071-2.048V9.1a.762.762 0 00-.766-.758.762.762 0 00-.766.758v3.8a2.06 2.06 0 01-2.071 2.049A2.06 2.06 0 010 12.9v-1.378c0-.357.292-.646.652-.646.36 0 .653.29.653.646V12.9c0 .418.343.757.766.757s.766-.339.766-.757V9.099a2.06 2.06 0 012.07-2.048 2.06 2.06 0 012.071 2.048v8.984c0 .419.343.758.767.758.423 0 .766-.339.766-.758V4.07c0-1.143.937-2.07 2.093-2.07z\"\n    />\n  </svg>\n{:else if family === \"kimi\"}\n  <svg class=\"w-6 h-6 {className}\" viewBox=\"0 0 24 24\" fill=\"currentColor\">\n    <path\n      d=\"M19.738 5.776c.163-.209.306-.4.457-.585.07-.087.064-.153-.004-.244-.655-.861-.717-1.817-.34-2.787.283-.73.909-1.072 1.674-1.145.477-.045.945.004 1.379.236.57.305.902.77 1.01 1.412.086.512.07 1.012-.075 1.508-.257.878-.888 1.333-1.753 1.448-.718.096-1.446.108-2.17.157-.056.004-.113 0-.178 0z\"\n    />\n    <path\n      d=\"M17.962 1.844h-4.326l-3.425 7.81H5.369V1.878H1.5V22h3.87v-8.477h6.824a3.025 3.025 0 002.743-1.75V22h3.87v-8.477a3.87 3.87 0 00-3.588-3.86v-.01h-2.125a3.94 3.94 0 002.323-2.12l2.545-5.689z\"\n    />\n  </svg>\n{:else if family === \"flux\"}\n  <svg class=\"w-6 h-6 {className}\" viewBox=\"0 0 24 24\" fill=\"currentColor\">\n    <path\n      d=\"M12 2L2 19h7.5l2.5-4.5L14.5 19H22L12 2zm0 4.5L16.5 17h-3l-1.5-2.7L10.5 17h-3L12 6.5z\"\n    />\n  </svg>\n{:else if family === \"qwen-image\"}\n  <svg class=\"w-6 h-6 {className}\" viewBox=\"0 0 24 24\" fill=\"currentColor\">\n    <path\n      d=\"M12.604 1.34c.393.69.784 1.382 1.174 2.075a.18.18 0 00.157.091h5.552c.174 0 .322.11.446.327l1.454 2.57c.19.337.24.478.024.837-.26.43-.513.864-.76 1.3l-.367.658c-.106.196-.223.28-.04.512l2.652 4.637c.172.301.111.494-.043.77-.437.785-.882 1.564-1.335 2.34-.159.272-.352.375-.68.37-.777-.016-1.552-.01-2.327.016a.099.099 0 00-.081.05 575.097 575.097 0 01-2.705 4.74c-.169.293-.38.363-.725.364-.997.003-2.002.004-3.017.002a.537.537 0 01-.465-.271l-1.335-2.323a.09.09 0 00-.083-.049H4.982c-.285.03-.553-.001-.805-.092l-1.603-2.77a.543.543 0 
01-.002-.54l1.207-2.12a.198.198 0 000-.197 550.951 550.951 0 01-1.875-3.272l-.79-1.395c-.16-.31-.173-.496.095-.965.465-.813.927-1.625 1.387-2.436.132-.234.304-.334.584-.335a338.3 338.3 0 012.589-.001.124.124 0 00.107-.063l2.806-4.895a.488.488 0 01.422-.246c.524-.001 1.053 0 1.583-.006L11.704 1c.341-.003.724.032.9.34zm-3.432.403a.06.06 0 00-.052.03L6.254 6.788a.157.157 0 01-.135.078H3.253c-.056 0-.07.025-.041.074l5.81 10.156c.025.042.013.062-.034.063l-2.795.015a.218.218 0 00-.2.116l-1.32 2.31c-.044.078-.021.118.068.118l5.716.008c.046 0 .08.02.104.061l1.403 2.454c.046.081.092.082.139 0l5.006-8.76.783-1.382a.055.055 0 01.096 0l1.424 2.53a.122.122 0 00.107.062l2.763-.02a.04.04 0 00.035-.02.041.041 0 000-.04l-2.9-5.086a.108.108 0 010-.113l.293-.507 1.12-1.977c.024-.041.012-.062-.035-.062H9.2c-.059 0-.073-.026-.043-.077l1.434-2.505a.107.107 0 000-.114L9.225 1.774a.06.06 0 00-.053-.031zm6.29 8.02c.046 0 .058.02.034.06l-.832 1.465-2.613 4.585a.056.056 0 01-.05.029.058.058 0 01-.05-.029L8.498 9.841c-.02-.034-.01-.052.028-.054l.216-.012 6.722-.012z\"\n    />\n  </svg>\n{:else if family === \"huggingface\"}\n  <svg class=\"w-6 h-6 {className}\" viewBox=\"0 0 24 24\" fill=\"currentColor\">\n    <path\n      d=\"M12.025 1.13c-5.77 0-10.449 4.647-10.449 10.378 0 1.112.178 2.181.503 3.185.064-.222.203-.444.416-.577a.96.96 0 0 1 .524-.15c.293 0 .584.124.84.284.278.173.48.408.71.694.226.282.458.611.684.951v-.014c.017-.324.106-.622.264-.874s.403-.487.762-.543c.3-.047.596.06.787.203s.31.313.4.467c.15.257.212.468.233.542.01.026.653 1.552 1.657 2.54.616.605 1.01 1.223 1.082 1.912.055.537-.096 1.059-.38 1.572.637.121 1.294.187 1.967.187.657 0 1.298-.063 1.921-.178-.287-.517-.44-1.041-.384-1.581.07-.69.465-1.307 1.081-1.913 1.004-.987 1.647-2.513 1.657-2.539.021-.074.083-.285.233-.542.09-.154.208-.323.4-.467a1.08 1.08 0 0 1 .787-.203c.359.056.604.29.762.543s.247.55.265.874v.015c.225-.34.457-.67.683-.952.23-.286.432-.52.71-.694.257-.16.547-.284.84-.285a.97.97 0 0 1 .524.151c.228.143.373.388.43.625l.006.04a10.3 10.3 0 0 0 .534-3.273c0-5.731-4.678-10.378-10.449-10.378M8.327 6.583a1.5 1.5 0 0 1 .713.174 1.487 1.487 0 0 1 .617 2.013c-.183.343-.762-.214-1.102-.094-.38.134-.532.914-.917.71a1.487 1.487 0 0 1 .69-2.803m7.486 0a1.487 1.487 0 0 1 .689 2.803c-.385.204-.536-.576-.916-.71-.34-.12-.92.437-1.103.094a1.487 1.487 0 0 1 .617-2.013 1.5 1.5 0 0 1 .713-.174m-10.68 1.55a.96.96 0 1 1 0 1.921.96.96 0 0 1 0-1.92m13.838 0a.96.96 0 1 1 0 1.92.96.96 0 0 1 0-1.92M8.489 11.458c.588.01 1.965 1.157 3.572 1.164 1.607-.007 2.984-1.155 3.572-1.164.196-.003.305.12.305.454 0 .886-.424 2.328-1.563 3.202-.22-.756-1.396-1.366-1.63-1.32q-.011.001-.02.006l-.044.026-.01.008-.03.024q-.018.017-.035.036l-.032.04a1 1 0 0 0-.058.09l-.014.025q-.049.088-.11.19a1 1 0 0 1-.083.116 1.2 1.2 0 0 1-.173.18q-.035.029-.075.058a1.3 1.3 0 0 1-.251-.243 1 1 0 0 1-.076-.107c-.124-.193-.177-.363-.337-.444-.034-.016-.104-.008-.2.022q-.094.03-.216.087-.06.028-.125.063l-.13.074q-.067.04-.136.086a3 3 0 0 0-.135.096 3 3 0 0 0-.26.219 2 2 0 0 0-.12.121 2 2 0 0 0-.106.128l-.002.002a2 2 0 0 0-.09.132l-.001.001a1.2 1.2 0 0 0-.105.212q-.013.036-.024.073c-1.139-.875-1.563-2.317-1.563-3.203 0-.334.109-.457.305-.454m.836 10.354c.824-1.19.766-2.082-.365-3.194-1.13-1.112-1.789-2.738-1.789-2.738s-.246-.945-.806-.858-.97 1.499.202 2.362c1.173.864-.233 1.45-.685.64-.45-.812-1.683-2.896-2.322-3.295s-1.089-.175-.938.647 2.822 2.813 2.562 3.244-1.176-.506-1.176-.506-2.866-2.567-3.49-1.898.473 1.23 2.037 2.16c1.564.932 1.686 1.178 1.464 1.53s-3.675-2.511-4-1.297c-.323 1.214 3.524 
1.567 3.287 2.405-.238.839-2.71-1.587-3.216-.642-.506.946 3.49 2.056 3.522 2.064 1.29.33 4.568 1.028 5.713-.624m5.349 0c-.824-1.19-.766-2.082.365-3.194 1.13-1.112 1.789-2.738 1.789-2.738s.246-.945.806-.858.97 1.499-.202 2.362c-1.173.864.233 1.45.685.64.451-.812 1.683-2.896 2.322-3.295s1.089-.175.938.647-2.822 2.813-2.562 3.244 1.176-.506 1.176-.506 2.866-2.567 3.49-1.898-.473 1.23-2.037 2.16c-1.564.932-1.686 1.178-1.464 1.53s3.675-2.511 4-1.297c.323 1.214-3.524 1.567-3.287 2.405.238.839 2.71-1.587 3.216-.642.506.946-3.49 2.056-3.522 2.064-1.29.33-4.568 1.028-5.713-.624\"\n    />\n  </svg>\n{:else if family === \"step\"}\n  <svg class=\"w-6 h-6 {className}\" viewBox=\"0 0 24 24\" fill=\"currentColor\">\n    <path\n      d=\"M22.012 0h1.032v.927H24v.968h-.956V3.78h-1.032V1.896h-1.878v-.97h1.878V0zM2.6 12.371V1.87h.969v10.502h-.97zm10.423.66h10.95v.918h-6.208v9.579h-4.742V13.03zM5.629 3.333v12.356H0v4.51h10.386V8L20.859 8l-.003-4.668-15.227.001z\"\n    />\n  </svg>\n{:else}\n  <svg class=\"w-6 h-6 {className}\" viewBox=\"0 0 24 24\" fill=\"currentColor\">\n    <path\n      d=\"M12 2C6.48 2 2 6.48 2 12s4.48 10 10 10 10-4.48 10-10S17.52 2 12 2zm-2 15l-5-5 1.41-1.41L10 14.17l7.59-7.59L19 8l-9 9z\"\n    />\n  </svg>\n{/if}\n"
  },
  {
    "path": "dashboard/src/lib/components/FamilySidebar.svelte",
    "content": "<script lang=\"ts\">\n  import FamilyLogos from \"./FamilyLogos.svelte\";\n\n  type FamilySidebarProps = {\n    families: string[];\n    selectedFamily: string | null;\n    hasFavorites: boolean;\n    hasRecents: boolean;\n    onSelect: (family: string | null) => void;\n  };\n\n  let {\n    families,\n    selectedFamily,\n    hasFavorites,\n    hasRecents,\n    onSelect,\n  }: FamilySidebarProps = $props();\n\n  // Family display names\n  const familyNames: Record<string, string> = {\n    favorites: \"Favorites\",\n    recents: \"Recent\",\n    huggingface: \"Hub\",\n    llama: \"Meta\",\n    qwen: \"Qwen\",\n    deepseek: \"DeepSeek\",\n    \"gpt-oss\": \"OpenAI\",\n    glm: \"GLM\",\n    minimax: \"MiniMax\",\n    kimi: \"Kimi\",\n    flux: \"FLUX\",\n    \"qwen-image\": \"Qwen Img\",\n  };\n\n  function getFamilyName(family: string): string {\n    return (\n      familyNames[family] || family.charAt(0).toUpperCase() + family.slice(1)\n    );\n  }\n</script>\n\n<div\n  class=\"flex flex-col gap-1 py-2 px-1 border-r border-exo-yellow/10 bg-exo-medium-gray/30 min-w-[72px] sm:min-w-[64px] overflow-y-auto scrollbar-hide\"\n>\n  <!-- All models (no filter) -->\n  <button\n    type=\"button\"\n    onclick={() => onSelect(null)}\n    class=\"group flex flex-col items-center justify-center p-2 sm:p-2 rounded transition-all duration-200 cursor-pointer min-h-[44px] sm:min-h-0 {selectedFamily ===\n    null\n      ? 'bg-exo-yellow/20 border-l-2 border-exo-yellow'\n      : 'hover:bg-white/5 border-l-2 border-transparent'}\"\n    title=\"All models\"\n  >\n    <svg\n      class=\"w-5 h-5 {selectedFamily === null\n        ? 'text-exo-yellow'\n        : 'text-white/50 group-hover:text-white/70'}\"\n      viewBox=\"0 0 24 24\"\n      fill=\"currentColor\"\n    >\n      <path\n        d=\"M4 8h4V4H4v4zm6 12h4v-4h-4v4zm-6 0h4v-4H4v4zm0-6h4v-4H4v4zm6 0h4v-4h-4v4zm6-10v4h4V4h-4zm-6 4h4V4h-4v4zm6 6h4v-4h-4v4zm0 6h4v-4h-4v4z\"\n      />\n    </svg>\n    <span\n      class=\"text-[9px] font-mono mt-0.5 {selectedFamily === null\n        ? 'text-exo-yellow'\n        : 'text-white/40 group-hover:text-white/60'}\">All</span\n    >\n  </button>\n\n  <!-- Favorites (only show if has favorites) -->\n  {#if hasFavorites}\n    <button\n      type=\"button\"\n      onclick={() => onSelect(\"favorites\")}\n      class=\"group flex flex-col items-center justify-center p-2 rounded transition-all duration-200 cursor-pointer {selectedFamily ===\n      'favorites'\n        ? 'bg-exo-yellow/20 border-l-2 border-exo-yellow'\n        : 'hover:bg-white/5 border-l-2 border-transparent'}\"\n      title=\"Show favorited models\"\n    >\n      <FamilyLogos\n        family=\"favorites\"\n        class={selectedFamily === \"favorites\"\n          ? \"text-amber-400\"\n          : \"text-white/50 group-hover:text-amber-400/70\"}\n      />\n      <span\n        class=\"text-[9px] font-mono mt-0.5 {selectedFamily === 'favorites'\n          ? 'text-amber-400'\n          : 'text-white/40 group-hover:text-white/60'}\">Faves</span\n      >\n    </button>\n  {/if}\n\n  <!-- Recent (only show if has recent models) -->\n  {#if hasRecents}\n    <button\n      type=\"button\"\n      onclick={() => onSelect(\"recents\")}\n      class=\"group flex flex-col items-center justify-center p-2 rounded transition-all duration-200 cursor-pointer {selectedFamily ===\n      'recents'\n        ? 
'bg-exo-yellow/20 border-l-2 border-exo-yellow'\n        : 'hover:bg-white/5 border-l-2 border-transparent'}\"\n      title=\"Recently launched models\"\n    >\n      <FamilyLogos\n        family=\"recents\"\n        class={selectedFamily === \"recents\"\n          ? \"text-exo-yellow\"\n          : \"text-white/50 group-hover:text-white/70\"}\n      />\n      <span\n        class=\"text-[9px] font-mono mt-0.5 {selectedFamily === 'recents'\n          ? 'text-exo-yellow'\n          : 'text-white/40 group-hover:text-white/60'}\">Recent</span\n      >\n    </button>\n  {/if}\n\n  <!-- HuggingFace Hub -->\n  <button\n    type=\"button\"\n    onclick={() => onSelect(\"huggingface\")}\n    class=\"group flex flex-col items-center justify-center p-2 rounded transition-all duration-200 cursor-pointer {selectedFamily ===\n    'huggingface'\n      ? 'bg-orange-500/20 border-l-2 border-orange-400'\n      : 'hover:bg-white/5 border-l-2 border-transparent'}\"\n    title=\"Browse and add models from Hugging Face\"\n  >\n    <FamilyLogos\n      family=\"huggingface\"\n      class={selectedFamily === \"huggingface\"\n        ? \"text-orange-400\"\n        : \"text-white/50 group-hover:text-orange-400/70\"}\n    />\n    <span\n      class=\"text-[9px] font-mono mt-0.5 {selectedFamily === 'huggingface'\n        ? 'text-orange-400'\n        : 'text-white/40 group-hover:text-white/60'}\">Hub</span\n    >\n  </button>\n\n  <div class=\"h-px bg-exo-yellow/10 my-1\"></div>\n\n  <!-- Model families -->\n  {#each families as family}\n    <button\n      type=\"button\"\n      onclick={() => onSelect(family)}\n      class=\"group flex flex-col items-center justify-center p-2 rounded transition-all duration-200 cursor-pointer {selectedFamily ===\n      family\n        ? 'bg-exo-yellow/20 border-l-2 border-exo-yellow'\n        : 'hover:bg-white/5 border-l-2 border-transparent'}\"\n      title={getFamilyName(family)}\n    >\n      <FamilyLogos\n        {family}\n        class={selectedFamily === family\n          ? \"text-exo-yellow\"\n          : \"text-white/50 group-hover:text-white/70\"}\n      />\n      <span\n        class=\"text-[9px] font-mono mt-0.5 truncate max-w-full {selectedFamily ===\n        family\n          ? 'text-exo-yellow'\n          : 'text-white/40 group-hover:text-white/60'}\"\n      >\n        {getFamilyName(family)}\n      </span>\n    </button>\n  {/each}\n</div>\n"
  },
  {
    "path": "dashboard/src/lib/components/HeaderNav.svelte",
    "content": "<script lang=\"ts\">\n  import { browser } from \"$app/environment\";\n\n  export let showHome = true;\n  export let onHome: (() => void) | null = null;\n  export let showSidebarToggle = false;\n  export let sidebarVisible = true;\n  export let onToggleSidebar: (() => void) | null = null;\n  export let showMobileMenuToggle = false;\n  export let mobileMenuOpen = false;\n  export let onToggleMobileMenu: (() => void) | null = null;\n  export let showMobileRightToggle = false;\n  export let mobileRightOpen = false;\n  export let onToggleMobileRight: (() => void) | null = null;\n  export let downloadProgress: {\n    count: number;\n    percentage: number;\n  } | null = null;\n\n  function handleHome(): void {\n    if (onHome) {\n      onHome();\n      return;\n    }\n    if (browser) {\n      // Hash router: send to root\n      window.location.hash = \"/\";\n    }\n  }\n\n  function handleToggleSidebar(): void {\n    if (onToggleSidebar) {\n      onToggleSidebar();\n    }\n  }\n\n  function handleToggleMobileMenu(): void {\n    if (onToggleMobileMenu) {\n      onToggleMobileMenu();\n    }\n  }\n\n  function handleToggleMobileRight(): void {\n    if (onToggleMobileRight) {\n      onToggleMobileRight();\n    }\n  }\n</script>\n\n<header\n  class=\"relative z-20 flex items-center justify-center px-4 md:px-6 pt-4 md:pt-8 pb-3 md:pb-4 bg-exo-dark-gray\"\n>\n  <!-- Left: Sidebar Toggle (desktop) or Mobile Sidebar Toggle (mobile) -->\n  <div\n    class=\"absolute left-4 md:left-6 top-1/2 -translate-y-1/2 flex items-center gap-2\"\n  >\n    <!-- Mobile sidebar toggle -->\n    <button\n      onclick={handleToggleMobileMenu}\n      class=\"p-2 rounded border border-exo-light-gray/30 hover:border-exo-yellow/50 hover:bg-exo-medium-gray/30 transition-colors cursor-pointer md:hidden\"\n      title={mobileMenuOpen ? \"Hide sidebar\" : \"Show sidebar\"}\n      aria-label={mobileMenuOpen\n        ? \"Hide conversation sidebar\"\n        : \"Show conversation sidebar\"}\n      aria-pressed={mobileMenuOpen}\n    >\n      <svg\n        fill=\"none\"\n        viewBox=\"0 0 24 24\"\n        stroke=\"currentColor\"\n        stroke-width=\"2\"\n        class=\"w-5 h-5 {mobileMenuOpen\n          ? 'text-exo-yellow'\n          : 'text-exo-light-gray'}\"\n      >\n        {#if mobileMenuOpen}\n          <path\n            stroke-linecap=\"round\"\n            stroke-linejoin=\"round\"\n            d=\"M11 19l-7-7 7-7m8 14l-7-7 7-7\"\n          ></path>\n        {:else}\n          <path\n            stroke-linecap=\"round\"\n            stroke-linejoin=\"round\"\n            d=\"M13 5l7 7-7 7M5 5l7 7-7 7\"\n          ></path>\n        {/if}\n      </svg>\n    </button>\n    <!-- Desktop sidebar toggle -->\n    <button\n      onclick={handleToggleSidebar}\n      class=\"p-2 rounded border border-exo-light-gray/30 hover:border-exo-yellow/50 hover:bg-exo-medium-gray/30 transition-colors cursor-pointer hidden md:block\"\n      title={sidebarVisible ? \"Hide sidebar\" : \"Show sidebar\"}\n      aria-label={sidebarVisible\n        ? \"Hide conversation sidebar\"\n        : \"Show conversation sidebar\"}\n      aria-pressed={sidebarVisible}\n    >\n      <svg\n        fill=\"none\"\n        viewBox=\"0 0 24 24\"\n        stroke=\"currentColor\"\n        stroke-width=\"2\"\n        class=\"w-5 h-5 {sidebarVisible\n          ? 
'text-exo-yellow'\n          : 'text-exo-light-gray'}\"\n      >\n        {#if sidebarVisible}\n          <path\n            stroke-linecap=\"round\"\n            stroke-linejoin=\"round\"\n            d=\"M11 19l-7-7 7-7m8 14l-7-7 7-7\"\n          ></path>\n        {:else}\n          <path\n            stroke-linecap=\"round\"\n            stroke-linejoin=\"round\"\n            d=\"M13 5l7 7-7 7M5 5l7 7-7 7\"\n          ></path>\n        {/if}\n      </svg>\n    </button>\n  </div>\n\n  <!-- Center: Logo (clickable to go home) -->\n  <button\n    onclick={handleHome}\n    class=\"bg-transparent border-none outline-none focus:outline-none transition-opacity duration-200 hover:opacity-90 {showHome\n      ? 'cursor-pointer'\n      : 'cursor-default'}\"\n    title={showHome ? \"Go to home\" : \"\"}\n    disabled={!showHome}\n  >\n    <img\n      src=\"/exo-logo.png\"\n      alt=\"EXO\"\n      class=\"h-12 md:h-18 drop-shadow-[0_0_4px_rgba(255,215,0,0.3)]\"\n    />\n  </button>\n\n  <!-- Right: Home + Downloads + Mobile Right Toggle -->\n  <nav\n    class=\"absolute right-4 md:right-6 top-1/2 -translate-y-1/2 flex items-center gap-2 md:gap-4\"\n    aria-label=\"Main navigation\"\n  >\n    <!-- Mobile right sidebar toggle (instances/models) - only show when not in chat mode -->\n    {#if showMobileRightToggle}\n      <button\n        onclick={handleToggleMobileRight}\n        class=\"p-2 rounded border border-exo-light-gray/30 hover:border-exo-yellow/50 hover:bg-exo-medium-gray/30 transition-colors cursor-pointer md:hidden\"\n        title={mobileRightOpen ? \"Hide instances\" : \"Show instances\"}\n        aria-label={mobileRightOpen\n          ? \"Hide instances panel\"\n          : \"Show instances panel\"}\n        aria-pressed={mobileRightOpen}\n      >\n        <svg\n          fill=\"none\"\n          viewBox=\"0 0 24 24\"\n          stroke=\"currentColor\"\n          stroke-width=\"2\"\n          class=\"w-5 h-5 {mobileRightOpen\n            ? 
'text-exo-yellow'\n            : 'text-exo-light-gray'}\"\n        >\n          {#if mobileRightOpen}\n            <path\n              stroke-linecap=\"round\"\n              stroke-linejoin=\"round\"\n              d=\"M13 5l7 7-7 7M5 5l7 7-7 7\"\n            ></path>\n          {:else}\n            <path\n              stroke-linecap=\"round\"\n              stroke-linejoin=\"round\"\n              d=\"M11 19l-7-7 7-7m8 14l-7-7 7-7\"\n            ></path>\n          {/if}\n        </svg>\n      </button>\n    {/if}\n    {#if showHome}\n      <button\n        onclick={handleHome}\n        class=\"flex text-sm text-white/70 hover:text-exo-yellow transition-colors tracking-wider uppercase items-center gap-2 cursor-pointer\"\n        title=\"Back to topology view\"\n      >\n        <svg\n          class=\"w-4 h-4\"\n          fill=\"none\"\n          viewBox=\"0 0 24 24\"\n          stroke=\"currentColor\"\n        >\n          <path\n            stroke-linecap=\"round\"\n            stroke-linejoin=\"round\"\n            stroke-width=\"2\"\n            d=\"M3 12l2-2m0 0l7-7 7 7M5 10v10a1 1 0 001 1h3m10-11l2 2m-2-2v10a1 1 0 01-1 1h-3m-6 0a1 1 0 001-1v-4a1 1 0 011-1h2a1 1 0 011 1v4a1 1 0 001 1m-6 0h6\"\n          />\n        </svg>\n        <span class=\"hidden sm:inline\">Home</span>\n      </button>\n    {/if}\n    <a\n      href=\"/#/downloads\"\n      class=\"text-xs md:text-sm text-white/70 hover:text-exo-yellow transition-colors tracking-wider uppercase flex items-center gap-1.5 md:gap-2 cursor-pointer\"\n      title=\"View downloads overview\"\n    >\n      {#if downloadProgress}\n        <!-- Compact download progress indicator -->\n        <div class=\"relative w-4 h-4 flex-shrink-0\">\n          <svg class=\"w-4 h-4 -rotate-90\" viewBox=\"0 0 20 20\">\n            <circle\n              cx=\"10\"\n              cy=\"10\"\n              r=\"8\"\n              fill=\"none\"\n              stroke=\"currentColor\"\n              stroke-width=\"2\"\n              opacity=\"0.2\"\n            />\n            <circle\n              cx=\"10\"\n              cy=\"10\"\n              r=\"8\"\n              fill=\"none\"\n              stroke=\"currentColor\"\n              stroke-width=\"2\"\n              stroke-dasharray={2 * Math.PI * 8}\n              stroke-dashoffset={2 *\n                Math.PI *\n                8 *\n                (1 - downloadProgress.percentage / 100)}\n              class=\"text-blue-400 transition-all duration-300\"\n            />\n          </svg>\n          <div\n            class=\"absolute inset-0 flex items-center justify-center text-[6px] font-mono text-blue-400\"\n          >\n            {downloadProgress.count}\n          </div>\n        </div>\n      {:else}\n        <svg\n          class=\"w-4 h-4\"\n          viewBox=\"0 0 24 24\"\n          fill=\"none\"\n          stroke=\"currentColor\"\n          stroke-width=\"2\"\n          stroke-linecap=\"round\"\n          stroke-linejoin=\"round\"\n        >\n          <path d=\"M12 3v12\" />\n          <path d=\"M7 12l5 5 5-5\" />\n          <path d=\"M5 21h14\" />\n        </svg>\n      {/if}\n      <span class=\"hidden sm:inline\">Downloads</span>\n    </a>\n  </nav>\n</header>\n"
  },
  {
    "path": "dashboard/src/lib/components/HuggingFaceResultItem.svelte",
    "content": "<script lang=\"ts\">\n  interface HuggingFaceModel {\n    id: string;\n    author: string;\n    downloads: number;\n    likes: number;\n    last_modified: string;\n    tags: string[];\n  }\n\n  type HuggingFaceResultItemProps = {\n    model: HuggingFaceModel;\n    isAdded: boolean;\n    isAdding: boolean;\n    onAdd: () => void;\n    onSelect: () => void;\n    downloadedOnNodes?: string[];\n  };\n\n  let {\n    model,\n    isAdded,\n    isAdding,\n    onAdd,\n    onSelect,\n    downloadedOnNodes = [],\n  }: HuggingFaceResultItemProps = $props();\n\n  function formatNumber(num: number | undefined): string {\n    if (num == null) return \"0\";\n    if (num >= 1000000) {\n      return `${(num / 1000000).toFixed(1)}M`;\n    } else if (num >= 1000) {\n      return `${(num / 1000).toFixed(1)}k`;\n    }\n    return num.toString();\n  }\n\n  // Show short name for mlx-community models, full ID for everything else\n  const modelName = $derived(\n    model.author === \"mlx-community\"\n      ? model.id.split(\"/\").pop() || model.id\n      : model.id,\n  );\n</script>\n\n<div\n  class=\"flex items-center justify-between gap-3 px-3 py-2.5 hover:bg-white/5 transition-colors border-b border-white/5 last:border-b-0\"\n>\n  <div class=\"flex-1 min-w-0\">\n    <div class=\"flex items-center gap-2\">\n      <span class=\"text-sm font-mono text-white truncate\" title={model.id}\n        >{modelName}</span\n      >\n      {#if downloadedOnNodes.length > 0}\n        <span\n          class=\"flex-shrink-0\"\n          title={`Downloaded on ${downloadedOnNodes.join(\", \")}`}\n        >\n          <svg\n            class=\"w-4 h-4\"\n            viewBox=\"0 0 24 24\"\n            fill=\"none\"\n            stroke=\"currentColor\"\n            stroke-width=\"2\"\n            stroke-linecap=\"round\"\n            stroke-linejoin=\"round\"\n          >\n            <path\n              class=\"text-white/40\"\n              d=\"M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z\"\n            />\n            <path class=\"text-green-400\" d=\"m9 13 2 2 4-4\" />\n          </svg>\n        </span>\n      {/if}\n      {#if isAdded}\n        <span\n          class=\"px-1.5 py-0.5 text-[10px] font-mono bg-green-500/20 text-green-400 rounded\"\n          >Added</span\n        >\n      {/if}\n    </div>\n    <div class=\"flex items-center gap-3 mt-0.5 text-xs text-white/40\">\n      <span class=\"truncate\">{model.author}</span>\n      <span\n        class=\"flex items-center gap-1 shrink-0\"\n        title=\"Downloads in the last 30 days\"\n      >\n        <svg\n          class=\"w-3 h-3\"\n          fill=\"none\"\n          stroke=\"currentColor\"\n          viewBox=\"0 0 24 24\"\n        >\n          <path\n            stroke-linecap=\"round\"\n            stroke-linejoin=\"round\"\n            stroke-width=\"2\"\n            d=\"M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4\"\n          />\n        </svg>\n        {formatNumber(model.downloads)}\n      </span>\n      <span\n        class=\"flex items-center gap-1 shrink-0\"\n        title=\"Community likes on Hugging Face\"\n      >\n        <svg\n          class=\"w-3 h-3\"\n          fill=\"none\"\n          stroke=\"currentColor\"\n          viewBox=\"0 0 24 24\"\n        >\n          <path\n            stroke-linecap=\"round\"\n            stroke-linejoin=\"round\"\n            stroke-width=\"2\"\n            d=\"M4.318 6.318a4.5 4.5 0 000 6.364L12 
20.364l7.682-7.682a4.5 4.5 0 00-6.364-6.364L12 7.636l-1.318-1.318a4.5 4.5 0 00-6.364 0z\"\n          />\n        </svg>\n        {formatNumber(model.likes)}\n      </span>\n    </div>\n  </div>\n\n  <div class=\"flex items-center gap-2 shrink-0\">\n    {#if isAdded}\n      <button\n        type=\"button\"\n        onclick={onSelect}\n        class=\"px-3 py-1.5 text-xs font-mono tracking-wider uppercase bg-exo-yellow/10 text-exo-yellow border border-exo-yellow/30 hover:bg-exo-yellow/20 transition-colors rounded cursor-pointer\"\n      >\n        Select\n      </button>\n    {:else}\n      <button\n        type=\"button\"\n        onclick={onAdd}\n        disabled={isAdding}\n        class=\"px-3 py-1.5 text-xs font-mono tracking-wider uppercase bg-orange-500/10 text-orange-400 border border-orange-400/30 hover:bg-orange-500/20 transition-colors rounded cursor-pointer disabled:opacity-50 disabled:cursor-not-allowed\"\n      >\n        {#if isAdding}\n          <span class=\"flex items-center gap-1.5\">\n            <span\n              class=\"w-3 h-3 border-2 border-orange-400 border-t-transparent rounded-full animate-spin\"\n            ></span>\n            Adding...\n          </span>\n        {:else}\n          + Add\n        {/if}\n      </button>\n    {/if}\n  </div>\n</div>\n"
  },
  {
    "path": "dashboard/src/lib/components/ImageLightbox.svelte",
    "content": "<script lang=\"ts\">\n  import { fade, fly } from \"svelte/transition\";\n  import { cubicOut } from \"svelte/easing\";\n\n  interface Props {\n    src: string | null;\n    onclose: () => void;\n  }\n\n  let { src, onclose }: Props = $props();\n\n  function handleKeydown(e: KeyboardEvent) {\n    if (e.key === \"Escape\") {\n      onclose();\n    }\n  }\n\n  function extensionFromSrc(dataSrc: string): string {\n    const match = dataSrc.match(/^data:image\\/(\\w+)/);\n    if (match) return match[1] === \"jpeg\" ? \"jpg\" : match[1];\n    const urlMatch = dataSrc.match(/\\.(\\w+)(?:\\?|$)/);\n    if (urlMatch) return urlMatch[1];\n    return \"png\";\n  }\n\n  function handleDownload(e: MouseEvent) {\n    e.stopPropagation();\n    if (!src) return;\n    const link = document.createElement(\"a\");\n    link.href = src;\n    link.download = `image-${Date.now()}.${extensionFromSrc(src)}`;\n    link.click();\n  }\n\n  function handleClose(e: MouseEvent) {\n    e.stopPropagation();\n    onclose();\n  }\n</script>\n\n<svelte:window onkeydown={src ? handleKeydown : undefined} />\n\n{#if src}\n  <div\n    class=\"fixed inset-0 z-50 bg-black/90 backdrop-blur-sm flex items-center justify-center\"\n    transition:fade={{ duration: 200 }}\n    onclick={onclose}\n    role=\"presentation\"\n    onintrostart={() => (document.body.style.overflow = \"hidden\")}\n    onoutroend={() => (document.body.style.overflow = \"\")}\n  >\n    <div class=\"absolute top-4 right-4 flex gap-2 z-10\">\n      <button\n        type=\"button\"\n        class=\"p-2 rounded-lg bg-exo-dark-gray/80 border border-exo-yellow/30 text-exo-yellow hover:bg-exo-dark-gray hover:border-exo-yellow/50 cursor-pointer transition-colors\"\n        onclick={handleDownload}\n        title=\"Download image\"\n      >\n        <svg\n          class=\"w-5 h-5\"\n          fill=\"none\"\n          viewBox=\"0 0 24 24\"\n          stroke=\"currentColor\"\n          stroke-width=\"2\"\n        >\n          <path\n            stroke-linecap=\"round\"\n            stroke-linejoin=\"round\"\n            d=\"M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4\"\n          />\n        </svg>\n      </button>\n      <button\n        type=\"button\"\n        class=\"p-2 rounded-lg bg-exo-dark-gray/80 border border-exo-yellow/30 text-exo-yellow hover:bg-exo-dark-gray hover:border-exo-yellow/50 cursor-pointer transition-colors\"\n        onclick={handleClose}\n        title=\"Close\"\n      >\n        <svg class=\"w-5 h-5\" viewBox=\"0 0 24 24\" fill=\"currentColor\">\n          <path\n            d=\"M19 6.41L17.59 5 12 10.59 6.41 5 5 6.41 10.59 12 5 17.59 6.41 19 12 13.41 17.59 19 19 17.59 13.41 12 19 6.41z\"\n          />\n        </svg>\n      </button>\n    </div>\n\n    <!-- svelte-ignore a11y_no_noninteractive_element_interactions, a11y_click_events_have_key_events -->\n    <img\n      {src}\n      alt=\"\"\n      class=\"max-w-[90vw] max-h-[90vh] object-contain rounded-lg shadow-2xl\"\n      transition:fly={{ y: 20, duration: 300, easing: cubicOut }}\n      onclick={(e) => e.stopPropagation()}\n    />\n  </div>\n{/if}\n"
  },
  {
    "path": "dashboard/src/lib/components/ImageParamsPanel.svelte",
    "content": "<script lang=\"ts\">\n  import {\n    imageGenerationParams,\n    setImageGenerationParams,\n    resetImageGenerationParams,\n    type ImageGenerationParams,\n  } from \"$lib/stores/app.svelte\";\n\n  interface Props {\n    isEditMode?: boolean;\n  }\n\n  let { isEditMode = false }: Props = $props();\n\n  let showAdvanced = $state(false);\n\n  // Custom dropdown state\n  let isSizeDropdownOpen = $state(false);\n  let isQualityDropdownOpen = $state(false);\n  let sizeButtonRef: HTMLButtonElement | undefined = $state();\n  let qualityButtonRef: HTMLButtonElement | undefined = $state();\n\n  const sizeDropdownPosition = $derived(() => {\n    if (!sizeButtonRef || !isSizeDropdownOpen)\n      return { top: 0, left: 0, width: 0 };\n    const rect = sizeButtonRef.getBoundingClientRect();\n    return { top: rect.top, left: rect.left, width: rect.width };\n  });\n\n  const qualityDropdownPosition = $derived(() => {\n    if (!qualityButtonRef || !isQualityDropdownOpen)\n      return { top: 0, left: 0, width: 0 };\n    const rect = qualityButtonRef.getBoundingClientRect();\n    return { top: rect.top, left: rect.left, width: rect.width };\n  });\n\n  const params = $derived(imageGenerationParams());\n\n  const inputFidelityOptions: ImageGenerationParams[\"inputFidelity\"][] = [\n    \"low\",\n    \"high\",\n  ];\n\n  const outputFormatOptions: ImageGenerationParams[\"outputFormat\"][] = [\n    \"png\",\n    \"jpeg\",\n  ];\n\n  function handleInputFidelityChange(\n    value: ImageGenerationParams[\"inputFidelity\"],\n  ) {\n    setImageGenerationParams({ inputFidelity: value });\n  }\n\n  function handleOutputFormatChange(\n    value: ImageGenerationParams[\"outputFormat\"],\n  ) {\n    setImageGenerationParams({ outputFormat: value });\n  }\n\n  const sizeOptions: ImageGenerationParams[\"size\"][] = [\n    \"auto\",\n    \"512x512\",\n    \"768x768\",\n    \"1024x1024\",\n    \"1024x768\",\n    \"768x1024\",\n    \"1024x1536\",\n    \"1536x1024\",\n  ];\n\n  const qualityOptions: ImageGenerationParams[\"quality\"][] = [\n    \"low\",\n    \"medium\",\n    \"high\",\n  ];\n\n  function selectSize(value: ImageGenerationParams[\"size\"]) {\n    setImageGenerationParams({ size: value });\n    isSizeDropdownOpen = false;\n  }\n\n  function selectQuality(value: ImageGenerationParams[\"quality\"]) {\n    setImageGenerationParams({ quality: value });\n    isQualityDropdownOpen = false;\n  }\n\n  function handleSeedChange(event: Event) {\n    const input = event.target as HTMLInputElement;\n    const value = input.value.trim();\n    if (value === \"\") {\n      setImageGenerationParams({ seed: null });\n    } else {\n      const num = parseInt(value, 10);\n      if (!isNaN(num) && num >= 0) {\n        setImageGenerationParams({ seed: num });\n      }\n    }\n  }\n\n  function handleStepsChange(event: Event) {\n    const value = parseInt((event.target as HTMLInputElement).value, 10);\n    setImageGenerationParams({ numInferenceSteps: value });\n  }\n\n  function handleGuidanceChange(event: Event) {\n    const value = parseFloat((event.target as HTMLInputElement).value);\n    setImageGenerationParams({ guidance: value });\n  }\n\n  function handleNegativePromptChange(event: Event) {\n    const value = (event.target as HTMLTextAreaElement).value;\n    setImageGenerationParams({ negativePrompt: value || null });\n  }\n\n  function handleNumImagesChange(event: Event) {\n    const input = event.target as HTMLInputElement;\n    const value = input.value.trim();\n    if (value === \"\") {\n      
setImageGenerationParams({ numImages: 1 });\n    } else {\n      const num = parseInt(value, 10);\n      if (!isNaN(num) && num >= 1) {\n        setImageGenerationParams({ numImages: num });\n      }\n    }\n  }\n\n  function handleStreamChange(enabled: boolean) {\n    setImageGenerationParams({ stream: enabled });\n  }\n\n  function handlePartialImagesChange(event: Event) {\n    const input = event.target as HTMLInputElement;\n    const value = input.value.trim();\n    if (value === \"\") {\n      setImageGenerationParams({ partialImages: 0 });\n    } else {\n      const num = parseInt(value, 10);\n      if (!isNaN(num) && num >= 0) {\n        setImageGenerationParams({ partialImages: num });\n      }\n    }\n  }\n\n  function clearSteps() {\n    setImageGenerationParams({ numInferenceSteps: null });\n  }\n\n  function clearGuidance() {\n    setImageGenerationParams({ guidance: null });\n  }\n\n  function handleNumSyncStepsChange(event: Event) {\n    const value = parseInt((event.target as HTMLInputElement).value, 10);\n    setImageGenerationParams({ numSyncSteps: value });\n  }\n\n  function clearNumSyncSteps() {\n    setImageGenerationParams({ numSyncSteps: null });\n  }\n\n  function handleReset() {\n    resetImageGenerationParams();\n    showAdvanced = false;\n  }\n\n  const hasAdvancedParams = $derived(\n    params.seed !== null ||\n      params.numInferenceSteps !== null ||\n      params.guidance !== null ||\n      (params.negativePrompt !== null && params.negativePrompt.trim() !== \"\") ||\n      params.numSyncSteps !== null,\n  );\n</script>\n\n<div class=\"border-b border-exo-medium-gray/30 px-3 py-2\">\n  <!-- Basic params row -->\n  <div class=\"flex items-center gap-3 flex-wrap\">\n    <!-- Size -->\n    <div class=\"flex items-center gap-1.5\">\n      <span class=\"text-xs text-exo-light-gray uppercase tracking-wider\"\n        >SIZE:</span\n      >\n      <div class=\"relative\">\n        <button\n          bind:this={sizeButtonRef}\n          type=\"button\"\n          onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}\n          class=\"bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen\n            ? 'border-exo-yellow/70'\n            : ''}\"\n        >\n          {params.size.toUpperCase()}\n        </button>\n        <div\n          class=\"absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen\n            ? 
'rotate-180'\n            : ''}\"\n        >\n          <svg\n            class=\"w-3 h-3 text-exo-yellow/60\"\n            fill=\"none\"\n            viewBox=\"0 0 24 24\"\n            stroke=\"currentColor\"\n          >\n            <path\n              stroke-linecap=\"round\"\n              stroke-linejoin=\"round\"\n              stroke-width=\"2\"\n              d=\"M19 9l-7 7-7-7\"\n            />\n          </svg>\n        </div>\n      </div>\n\n      {#if isSizeDropdownOpen}\n        <!-- Backdrop to close dropdown -->\n        <button\n          type=\"button\"\n          class=\"fixed inset-0 z-[9998] cursor-default\"\n          onclick={() => (isSizeDropdownOpen = false)}\n          aria-label=\"Close dropdown\"\n        ></button>\n\n        <!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->\n        <div\n          class=\"fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto overflow-x-hidden min-w-max\"\n          style=\"bottom: calc(100vh - {sizeDropdownPosition()\n            .top}px + 4px); left: {sizeDropdownPosition().left}px;\"\n        >\n          <div class=\"py-1\">\n            {#each sizeOptions as size}\n              <button\n                type=\"button\"\n                onclick={() => selectSize(size)}\n                class=\"w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===\n                size\n                  ? 'bg-transparent text-exo-yellow'\n                  : 'text-exo-light-gray hover:text-exo-yellow'}\"\n              >\n                {#if params.size === size}\n                  <svg\n                    class=\"w-3 h-3 flex-shrink-0\"\n                    fill=\"currentColor\"\n                    viewBox=\"0 0 20 20\"\n                  >\n                    <path\n                      fill-rule=\"evenodd\"\n                      d=\"M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z\"\n                      clip-rule=\"evenodd\"\n                    />\n                  </svg>\n                {:else}\n                  <span class=\"w-3\"></span>\n                {/if}\n                <span>{size.toUpperCase()}</span>\n              </button>\n            {/each}\n          </div>\n        </div>\n      {/if}\n    </div>\n\n    <!-- Quality -->\n    <div class=\"flex items-center gap-1.5\">\n      <span class=\"text-xs text-exo-light-gray uppercase tracking-wider\"\n        >QUALITY:</span\n      >\n      <div class=\"relative\">\n        <button\n          bind:this={qualityButtonRef}\n          type=\"button\"\n          onclick={() => (isQualityDropdownOpen = !isQualityDropdownOpen)}\n          class=\"bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isQualityDropdownOpen\n            ? 'border-exo-yellow/70'\n            : ''}\"\n        >\n          {params.quality.toUpperCase()}\n        </button>\n        <div\n          class=\"absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isQualityDropdownOpen\n            ? 
'rotate-180'\n            : ''}\"\n        >\n          <svg\n            class=\"w-3 h-3 text-exo-yellow/60\"\n            fill=\"none\"\n            viewBox=\"0 0 24 24\"\n            stroke=\"currentColor\"\n          >\n            <path\n              stroke-linecap=\"round\"\n              stroke-linejoin=\"round\"\n              stroke-width=\"2\"\n              d=\"M19 9l-7 7-7-7\"\n            />\n          </svg>\n        </div>\n      </div>\n\n      {#if isQualityDropdownOpen}\n        <!-- Backdrop to close dropdown -->\n        <button\n          type=\"button\"\n          class=\"fixed inset-0 z-[9998] cursor-default\"\n          onclick={() => (isQualityDropdownOpen = false)}\n          aria-label=\"Close dropdown\"\n        ></button>\n\n        <!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->\n        <div\n          class=\"fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto overflow-x-hidden min-w-max\"\n          style=\"bottom: calc(100vh - {qualityDropdownPosition()\n            .top}px + 4px); left: {qualityDropdownPosition().left}px;\"\n        >\n          <div class=\"py-1\">\n            {#each qualityOptions as quality}\n              <button\n                type=\"button\"\n                onclick={() => selectQuality(quality)}\n                class=\"w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.quality ===\n                quality\n                  ? 'bg-transparent text-exo-yellow'\n                  : 'text-exo-light-gray hover:text-exo-yellow'}\"\n              >\n                {#if params.quality === quality}\n                  <svg\n                    class=\"w-3 h-3 flex-shrink-0\"\n                    fill=\"currentColor\"\n                    viewBox=\"0 0 20 20\"\n                  >\n                    <path\n                      fill-rule=\"evenodd\"\n                      d=\"M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z\"\n                      clip-rule=\"evenodd\"\n                    />\n                  </svg>\n                {:else}\n                  <span class=\"w-3\"></span>\n                {/if}\n                <span>{quality.toUpperCase()}</span>\n              </button>\n            {/each}\n          </div>\n        </div>\n      {/if}\n    </div>\n\n    <!-- Format -->\n    <div class=\"flex items-center gap-1.5\">\n      <span class=\"text-xs text-exo-light-gray uppercase tracking-wider\"\n        >FORMAT:</span\n      >\n      <div class=\"flex rounded overflow-hidden border border-exo-yellow/30\">\n        {#each outputFormatOptions as format}\n          <button\n            type=\"button\"\n            onclick={() => handleOutputFormatChange(format)}\n            class=\"px-2 py-1 text-xs font-mono uppercase transition-all duration-200 cursor-pointer {params.outputFormat ===\n            format\n              ? 
'bg-exo-yellow text-exo-black'\n              : 'bg-exo-medium-gray/50 text-exo-light-gray hover:text-exo-yellow'}\"\n          >\n            {format}\n          </button>\n        {/each}\n      </div>\n    </div>\n\n    <!-- Number of Images (not in edit mode) -->\n    {#if !isEditMode}\n      <div class=\"flex items-center gap-1.5\">\n        <span class=\"text-xs text-exo-light-gray uppercase tracking-wider\"\n          >IMAGES:</span\n        >\n        <input\n          type=\"number\"\n          min=\"1\"\n          value={params.numImages}\n          oninput={handleNumImagesChange}\n          class=\"w-12 bg-exo-medium-gray/50 border border-exo-yellow/30 rounded px-2 py-1 text-xs font-mono text-exo-yellow text-center transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70\"\n        />\n      </div>\n    {/if}\n\n    <!-- Stream toggle -->\n    <div class=\"flex items-center gap-1.5\">\n      <span class=\"text-xs text-exo-light-gray uppercase tracking-wider\"\n        >STREAM:</span\n      >\n      <button\n        type=\"button\"\n        onclick={() => handleStreamChange(!params.stream)}\n        class=\"w-8 h-4 rounded-full transition-all duration-200 cursor-pointer relative {params.stream\n          ? 'bg-exo-yellow'\n          : 'bg-exo-medium-gray/50 border border-exo-yellow/30'}\"\n        title={params.stream ? \"Streaming enabled\" : \"Streaming disabled\"}\n      >\n        <div\n          class=\"absolute top-0.5 w-3 h-3 rounded-full transition-all duration-200 {params.stream\n            ? 'right-0.5 bg-exo-black'\n            : 'left-0.5 bg-exo-light-gray'}\"\n        ></div>\n      </button>\n    </div>\n\n    <!-- Partial Images (only when streaming) -->\n    {#if params.stream}\n      <div class=\"flex items-center gap-1.5\">\n        <span class=\"text-xs text-exo-light-gray uppercase tracking-wider\"\n          >PARTIALS:</span\n        >\n        <input\n          type=\"number\"\n          min=\"0\"\n          value={params.partialImages}\n          oninput={handlePartialImagesChange}\n          class=\"w-12 bg-exo-medium-gray/50 border border-exo-yellow/30 rounded px-2 py-1 text-xs font-mono text-exo-yellow text-center transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70\"\n        />\n      </div>\n    {/if}\n\n    <!-- Input Fidelity (edit mode only) -->\n    {#if isEditMode}\n      <div class=\"flex items-center gap-1.5\">\n        <span class=\"text-xs text-exo-light-gray uppercase tracking-wider\"\n          >FIDELITY:</span\n        >\n        <div class=\"flex rounded overflow-hidden border border-exo-yellow/30\">\n          {#each inputFidelityOptions as fidelity}\n            <button\n              type=\"button\"\n              onclick={() => handleInputFidelityChange(fidelity)}\n              class=\"px-2 py-1 text-xs font-mono uppercase transition-all duration-200 cursor-pointer {params.inputFidelity ===\n              fidelity\n                ? 'bg-exo-yellow text-exo-black'\n                : 'bg-exo-medium-gray/50 text-exo-light-gray hover:text-exo-yellow'}\"\n              title={fidelity === \"low\"\n                ? 
\"More creative variation\"\n                : \"Closer to original\"}\n            >\n              {fidelity}\n            </button>\n          {/each}\n        </div>\n      </div>\n    {/if}\n\n    <!-- Spacer -->\n    <div class=\"flex-1\"></div>\n\n    <!-- Advanced toggle -->\n    <button\n      type=\"button\"\n      onclick={() => (showAdvanced = !showAdvanced)}\n      class=\"flex items-center gap-1 text-xs font-mono tracking-wider uppercase transition-colors duration-200 {showAdvanced ||\n      hasAdvancedParams\n        ? 'text-exo-yellow'\n        : 'text-exo-light-gray hover:text-exo-yellow'}\"\n    >\n      <span>ADVANCED</span>\n      <svg\n        class=\"w-3 h-3 transition-transform duration-200 {showAdvanced\n          ? 'rotate-180'\n          : ''}\"\n        fill=\"none\"\n        viewBox=\"0 0 24 24\"\n        stroke=\"currentColor\"\n      >\n        <path\n          stroke-linecap=\"round\"\n          stroke-linejoin=\"round\"\n          stroke-width=\"2\"\n          d=\"M19 9l-7 7-7-7\"\n        />\n      </svg>\n      {#if hasAdvancedParams && !showAdvanced}\n        <span class=\"w-1.5 h-1.5 rounded-full bg-exo-yellow\"></span>\n      {/if}\n    </button>\n  </div>\n\n  <!-- Advanced params section -->\n  {#if showAdvanced}\n    <div class=\"mt-3 pt-3 border-t border-exo-medium-gray/20 space-y-3\">\n      <!-- Row 1: Seed and Steps -->\n      <div class=\"flex items-center gap-4 flex-wrap\">\n        <!-- Seed -->\n        <div class=\"flex items-center gap-1.5\">\n          <span class=\"text-xs text-exo-light-gray uppercase tracking-wider\"\n            >SEED:</span\n          >\n          <input\n            type=\"number\"\n            min=\"0\"\n            value={params.seed ?? \"\"}\n            oninput={handleSeedChange}\n            placeholder=\"Random\"\n            class=\"w-24 bg-exo-medium-gray/50 border border-exo-yellow/30 rounded px-2 py-1 text-xs font-mono text-exo-yellow placeholder:text-exo-light-gray/50 transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70\"\n          />\n        </div>\n\n        <!-- Steps Slider -->\n        <div class=\"flex items-center gap-1.5 flex-1 min-w-[200px]\">\n          <span\n            class=\"text-xs text-exo-light-gray uppercase tracking-wider whitespace-nowrap\"\n            >STEPS:</span\n          >\n          <div class=\"flex items-center gap-2 flex-1\">\n            <input\n              type=\"range\"\n              min=\"1\"\n              max=\"100\"\n              value={params.numInferenceSteps ?? 50}\n              oninput={handleStepsChange}\n              class=\"flex-1 h-1 bg-exo-medium-gray/50 rounded appearance-none cursor-pointer accent-exo-yellow\"\n            />\n            <span class=\"text-xs font-mono text-exo-yellow w-8 text-right\">\n              {params.numInferenceSteps ?? 
\"--\"}\n            </span>\n            {#if params.numInferenceSteps !== null}\n              <button\n                type=\"button\"\n                onclick={clearSteps}\n                class=\"text-exo-light-gray hover:text-exo-yellow transition-colors\"\n                title=\"Clear\"\n              >\n                <svg\n                  class=\"w-3 h-3\"\n                  fill=\"none\"\n                  viewBox=\"0 0 24 24\"\n                  stroke=\"currentColor\"\n                >\n                  <path\n                    stroke-linecap=\"round\"\n                    stroke-linejoin=\"round\"\n                    stroke-width=\"2\"\n                    d=\"M6 18L18 6M6 6l12 12\"\n                  />\n                </svg>\n              </button>\n            {/if}\n          </div>\n        </div>\n      </div>\n\n      <!-- Row 2: Guidance -->\n      <div class=\"flex items-center gap-1.5\">\n        <span\n          class=\"text-xs text-exo-light-gray uppercase tracking-wider whitespace-nowrap\"\n          >GUIDANCE:</span\n        >\n        <div class=\"flex items-center gap-2 flex-1 max-w-xs\">\n          <input\n            type=\"range\"\n            min=\"1\"\n            max=\"20\"\n            step=\"0.5\"\n            value={params.guidance ?? 7.5}\n            oninput={handleGuidanceChange}\n            class=\"flex-1 h-1 bg-exo-medium-gray/50 rounded appearance-none cursor-pointer accent-exo-yellow\"\n          />\n          <span class=\"text-xs font-mono text-exo-yellow w-8 text-right\">\n            {params.guidance !== null ? params.guidance.toFixed(1) : \"--\"}\n          </span>\n          {#if params.guidance !== null}\n            <button\n              type=\"button\"\n              onclick={clearGuidance}\n              class=\"text-exo-light-gray hover:text-exo-yellow transition-colors\"\n              title=\"Clear\"\n            >\n              <svg\n                class=\"w-3 h-3\"\n                fill=\"none\"\n                viewBox=\"0 0 24 24\"\n                stroke=\"currentColor\"\n              >\n                <path\n                  stroke-linecap=\"round\"\n                  stroke-linejoin=\"round\"\n                  stroke-width=\"2\"\n                  d=\"M6 18L18 6M6 6l12 12\"\n                />\n              </svg>\n            </button>\n          {/if}\n        </div>\n      </div>\n\n      <!-- Row 3: Sync Steps -->\n      <div class=\"flex items-center gap-1.5\">\n        <span\n          class=\"text-xs text-exo-light-gray uppercase tracking-wider whitespace-nowrap\"\n          >SYNC STEPS:</span\n        >\n        <div class=\"flex items-center gap-2 flex-1 max-w-xs\">\n          <input\n            type=\"range\"\n            min=\"1\"\n            max=\"100\"\n            value={params.numSyncSteps ?? 1}\n            oninput={handleNumSyncStepsChange}\n            class=\"flex-1 h-1 bg-exo-medium-gray/50 rounded appearance-none cursor-pointer accent-exo-yellow\"\n          />\n          <span class=\"text-xs font-mono text-exo-yellow w-8 text-right\">\n            {params.numSyncSteps ?? 
\"--\"}\n          </span>\n          {#if params.numSyncSteps !== null}\n            <button\n              type=\"button\"\n              onclick={clearNumSyncSteps}\n              class=\"text-exo-light-gray hover:text-exo-yellow transition-colors\"\n              title=\"Clear\"\n            >\n              <svg\n                class=\"w-3 h-3\"\n                fill=\"none\"\n                viewBox=\"0 0 24 24\"\n                stroke=\"currentColor\"\n              >\n                <path\n                  stroke-linecap=\"round\"\n                  stroke-linejoin=\"round\"\n                  stroke-width=\"2\"\n                  d=\"M6 18L18 6M6 6l12 12\"\n                />\n              </svg>\n            </button>\n          {/if}\n        </div>\n      </div>\n\n      <!-- Row 4: Negative Prompt -->\n      <div class=\"flex flex-col gap-1.5\">\n        <span class=\"text-xs text-exo-light-gray uppercase tracking-wider\"\n          >NEGATIVE PROMPT:</span\n        >\n        <textarea\n          value={params.negativePrompt ?? \"\"}\n          oninput={handleNegativePromptChange}\n          placeholder=\"Things to avoid in the image...\"\n          rows={2}\n          class=\"w-full bg-exo-medium-gray/50 border border-exo-yellow/30 rounded px-2 py-1.5 text-xs font-mono text-exo-yellow placeholder:text-exo-light-gray/50 resize-none transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70\"\n        ></textarea>\n      </div>\n\n      <!-- Reset Button -->\n      <div class=\"flex justify-end pt-1\">\n        <button\n          type=\"button\"\n          onclick={handleReset}\n          class=\"text-xs font-mono tracking-wider uppercase text-exo-light-gray hover:text-exo-yellow transition-colors duration-200\"\n        >\n          RESET TO DEFAULTS\n        </button>\n      </div>\n    </div>\n  {/if}\n</div>\n\n<style>\n  /* Custom range slider styling */\n  input[type=\"range\"]::-webkit-slider-thumb {\n    -webkit-appearance: none;\n    appearance: none;\n    width: 12px;\n    height: 12px;\n    border-radius: 50%;\n    background: #ffd700;\n    cursor: pointer;\n    border: none;\n  }\n\n  input[type=\"range\"]::-moz-range-thumb {\n    width: 12px;\n    height: 12px;\n    border-radius: 50%;\n    background: #ffd700;\n    cursor: pointer;\n    border: none;\n  }\n\n  /* Hide number input spinners */\n  input[type=\"number\"]::-webkit-inner-spin-button,\n  input[type=\"number\"]::-webkit-outer-spin-button {\n    -webkit-appearance: none;\n    margin: 0;\n  }\n\n  input[type=\"number\"] {\n    -moz-appearance: textfield;\n  }\n</style>\n"
  },
  {
    "path": "dashboard/src/lib/components/MarkdownContent.svelte",
    "content": "<script lang=\"ts\">\n  import { marked } from \"marked\";\n  import hljs from \"highlight.js\";\n  import katex from \"katex\";\n  import \"katex/dist/katex.min.css\";\n  import { browser } from \"$app/environment\";\n\n  interface Props {\n    content: string;\n    class?: string;\n  }\n\n  let { content, class: className = \"\" }: Props = $props();\n\n  let containerRef = $state<HTMLDivElement>();\n  let processedHtml = $state(\"\");\n\n  // Configure marked with syntax highlighting\n  marked.setOptions({\n    gfm: true,\n    breaks: true,\n  });\n\n  // Custom renderer for code blocks\n  const renderer = new marked.Renderer();\n\n  renderer.code = function ({ text, lang }: { text: string; lang?: string }) {\n    const language = lang && hljs.getLanguage(lang) ? lang : \"plaintext\";\n    const highlighted = hljs.highlight(text, { language }).value;\n    const codeId = `code-${Date.now()}-${Math.random().toString(36).slice(2, 9)}`;\n\n    return `\n\t\t\t<div class=\"code-block-wrapper\">\n\t\t\t\t<div class=\"code-block-header\">\n\t\t\t\t\t<span class=\"code-language\">${language}</span>\n\t\t\t\t\t<button type=\"button\" class=\"copy-code-btn\" data-code=\"${encodeURIComponent(text)}\" title=\"Copy code\">\n\t\t\t\t\t\t<svg width=\"16\" height=\"16\" viewBox=\"0 0 24 24\" fill=\"none\" stroke=\"currentColor\" stroke-width=\"2\" stroke-linecap=\"round\" stroke-linejoin=\"round\">\n\t\t\t\t\t\t\t<rect width=\"14\" height=\"14\" x=\"8\" y=\"8\" rx=\"2\" ry=\"2\"/>\n\t\t\t\t\t\t\t<path d=\"M4 16c-1.1 0-2-.9-2-2V4c0-1.1.9-2 2-2h10c1.1 0 2 .9 2 2\"/>\n\t\t\t\t\t\t</svg>\n\t\t\t\t\t</button>\n\t\t\t\t</div>\n\t\t\t\t<pre><code class=\"hljs language-${language}\" data-code-id=\"${codeId}\">${highlighted}</code></pre>\n\t\t\t</div>\n\t\t`;\n  };\n\n  // Inline code\n  renderer.codespan = function ({ text }: { text: string }) {\n    return `<code class=\"inline-code\">${text}</code>`;\n  };\n\n  marked.use({ renderer });\n\n  /**\n   * Unescape HTML entities that marked may have escaped\n   */\n  function unescapeHtmlEntities(text: string): string {\n    return text\n      .replace(/&lt;/g, \"<\")\n      .replace(/&gt;/g, \">\")\n      .replace(/&amp;/g, \"&\")\n      .replace(/&quot;/g, '\"')\n      .replace(/&#39;/g, \"'\");\n  }\n\n  // Storage for math expressions extracted before markdown processing\n  const mathExpressions: Map<\n    string,\n    { content: string; displayMode: boolean }\n  > = new Map();\n  let mathCounter = 0;\n\n  // Storage for HTML snippets that need protection from markdown\n  const htmlSnippets: Map<string, string> = new Map();\n  let htmlCounter = 0;\n\n  // Use alphanumeric placeholders that won't be interpreted as HTML tags\n  const MATH_PLACEHOLDER_PREFIX = \"MATHPLACEHOLDER\";\n  const CODE_PLACEHOLDER_PREFIX = \"CODEPLACEHOLDER\";\n  const HTML_PLACEHOLDER_PREFIX = \"HTMLPLACEHOLDER\";\n\n  /**\n   * Preprocess LaTeX: extract math, handle LaTeX document commands, and protect content\n   */\n  function preprocessLaTeX(text: string): string {\n    // Reset storage\n    mathExpressions.clear();\n    mathCounter = 0;\n    htmlSnippets.clear();\n    htmlCounter = 0;\n\n    // Protect code blocks first\n    const codeBlocks: string[] = [];\n    let processed = text.replace(/```[\\s\\S]*?```|`[^`]+`/g, (match) => {\n      codeBlocks.push(match);\n      return `${CODE_PLACEHOLDER_PREFIX}${codeBlocks.length - 1}END`;\n    });\n\n    // Remove LaTeX document commands\n    processed = processed.replace(/\\\\documentclass(\\[[^\\]]*\\])?\\{[^}]*\\}/g, 
\"\");\n    processed = processed.replace(/\\\\usepackage(\\[[^\\]]*\\])?\\{[^}]*\\}/g, \"\");\n    processed = processed.replace(/\\\\begin\\{document\\}/g, \"\");\n    processed = processed.replace(/\\\\end\\{document\\}/g, \"\");\n    processed = processed.replace(/\\\\maketitle/g, \"\");\n    processed = processed.replace(/\\\\title\\{[^}]*\\}/g, \"\");\n    processed = processed.replace(/\\\\author\\{[^}]*\\}/g, \"\");\n    processed = processed.replace(/\\\\date\\{[^}]*\\}/g, \"\");\n\n    // Remove \\require{...} commands (MathJax-specific, not supported by KaTeX)\n    processed = processed.replace(/\\$\\\\require\\{[^}]*\\}\\$/g, \"\");\n    processed = processed.replace(/\\\\require\\{[^}]*\\}/g, \"\");\n\n    // Remove unsupported LaTeX commands/environments (tikzpicture, figure, center, etc.)\n    processed = processed.replace(\n      /\\\\begin\\{tikzpicture\\}[\\s\\S]*?\\\\end\\{tikzpicture\\}/g,\n      () => {\n        const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;\n        htmlSnippets.set(\n          placeholder,\n          '<div class=\"latex-diagram-placeholder\"><span class=\"latex-diagram-icon\">📐</span><span class=\"latex-diagram-text\">Diagram</span></div>',\n        );\n        htmlCounter++;\n        return placeholder;\n      },\n    );\n    processed = processed.replace(\n      /\\\\begin\\{figure\\}[\\s\\S]*?\\\\end\\{figure\\}/g,\n      () => {\n        const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;\n        htmlSnippets.set(\n          placeholder,\n          '<div class=\"latex-diagram-placeholder\"><span class=\"latex-diagram-icon\">🖼️</span><span class=\"latex-diagram-text\">Figure</span></div>',\n        );\n        htmlCounter++;\n        return placeholder;\n      },\n    );\n    // Strip center environment (layout only, no content change)\n    processed = processed.replace(/\\\\begin\\{center\\}/g, \"\");\n    processed = processed.replace(/\\\\end\\{center\\}/g, \"\");\n    // Strip other layout environments\n    processed = processed.replace(/\\\\begin\\{flushleft\\}/g, \"\");\n    processed = processed.replace(/\\\\end\\{flushleft\\}/g, \"\");\n    processed = processed.replace(/\\\\begin\\{flushright\\}/g, \"\");\n    processed = processed.replace(/\\\\end\\{flushright\\}/g, \"\");\n    processed = processed.replace(/\\\\label\\{[^}]*\\}/g, \"\");\n    processed = processed.replace(/\\\\caption\\{[^}]*\\}/g, \"\");\n\n    // Protect escaped dollar signs (e.g., \\$50 should become $50, not LaTeX)\n    processed = processed.replace(/\\\\\\$/g, \"ESCAPEDDOLLARPLACEHOLDER\");\n\n    // Convert LaTeX math environments to display math (both bare and wrapped in $...$)\n    const mathEnvs = [\n      \"align\",\n      \"align\\\\*\",\n      \"equation\",\n      \"equation\\\\*\",\n      \"gather\",\n      \"gather\\\\*\",\n      \"multline\",\n      \"multline\\\\*\",\n      \"eqnarray\",\n      \"eqnarray\\\\*\",\n      \"array\",\n      \"matrix\",\n      \"pmatrix\",\n      \"bmatrix\",\n      \"vmatrix\",\n      \"cases\",\n    ];\n    for (const env of mathEnvs) {\n      // Handle $\\begin{env}...\\end{env}$ (with dollar signs, possibly multiline)\n      const wrappedRegex = new RegExp(\n        `\\\\$\\\\\\\\begin\\\\{${env}\\\\}(\\\\{[^}]*\\\\})?([\\\\s\\\\S]*?)\\\\\\\\end\\\\{${env}\\\\}\\\\$`,\n        \"g\",\n      );\n      processed = processed.replace(wrappedRegex, (_, args, content) => {\n        const cleanEnv = env.replace(\"\\\\*\", \"*\");\n        const mathContent = `\\\\begin{${cleanEnv}}${args || 
\"\"}${content}\\\\end{${cleanEnv}}`;\n        const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;\n        mathExpressions.set(placeholder, {\n          content: mathContent,\n          displayMode: true,\n        });\n        mathCounter++;\n        return placeholder;\n      });\n\n      // Handle bare \\begin{env}...\\end{env} (without dollar signs)\n      const bareRegex = new RegExp(\n        `\\\\\\\\begin\\\\{${env}\\\\}(\\\\{[^}]*\\\\})?([\\\\s\\\\S]*?)\\\\\\\\end\\\\{${env}\\\\}`,\n        \"g\",\n      );\n      processed = processed.replace(bareRegex, (_, args, content) => {\n        const cleanEnv = env.replace(\"\\\\*\", \"*\");\n        const mathContent = `\\\\begin{${cleanEnv}}${args || \"\"}${content}\\\\end{${cleanEnv}}`;\n        const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;\n        mathExpressions.set(placeholder, {\n          content: mathContent,\n          displayMode: true,\n        });\n        mathCounter++;\n        return placeholder;\n      });\n    }\n\n    // Convert LaTeX proof environments to styled blocks (use placeholders for HTML)\n    processed = processed.replace(\n      /\\\\begin\\{proof\\}([\\s\\S]*?)\\\\end\\{proof\\}/g,\n      (_, content) => {\n        const html = `<div class=\"latex-proof\"><div class=\"latex-proof-header\">Proof</div><div class=\"latex-proof-content\">${content}</div></div>`;\n        const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;\n        htmlSnippets.set(placeholder, html);\n        htmlCounter++;\n        return placeholder;\n      },\n    );\n\n    // Convert LaTeX theorem-like environments\n    const theoremEnvs = [\n      \"theorem\",\n      \"lemma\",\n      \"corollary\",\n      \"proposition\",\n      \"definition\",\n      \"remark\",\n      \"example\",\n    ];\n    for (const env of theoremEnvs) {\n      const envRegex = new RegExp(\n        `\\\\\\\\begin\\\\{${env}\\\\}([\\\\s\\\\S]*?)\\\\\\\\end\\\\{${env}\\\\}`,\n        \"gi\",\n      );\n      const envName = env.charAt(0).toUpperCase() + env.slice(1);\n      processed = processed.replace(envRegex, (_, content) => {\n        const html = `<div class=\"latex-theorem\"><div class=\"latex-theorem-header\">${envName}</div><div class=\"latex-theorem-content\">${content}</div></div>`;\n        const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;\n        htmlSnippets.set(placeholder, html);\n        htmlCounter++;\n        return placeholder;\n      });\n    }\n\n    // Convert LaTeX text formatting commands (use placeholders to protect from markdown)\n    processed = processed.replace(/\\\\emph\\{([^}]*)\\}/g, (_, content) => {\n      const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;\n      htmlSnippets.set(placeholder, `<em>${content}</em>`);\n      htmlCounter++;\n      return placeholder;\n    });\n    processed = processed.replace(/\\\\textit\\{([^}]*)\\}/g, (_, content) => {\n      const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;\n      htmlSnippets.set(placeholder, `<em>${content}</em>`);\n      htmlCounter++;\n      return placeholder;\n    });\n    processed = processed.replace(/\\\\textbf\\{([^}]*)\\}/g, (_, content) => {\n      const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;\n      htmlSnippets.set(placeholder, `<strong>${content}</strong>`);\n      htmlCounter++;\n      return placeholder;\n    });\n    processed = processed.replace(/\\\\texttt\\{([^}]*)\\}/g, (_, content) => {\n      const placeholder = 
`${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;\n      htmlSnippets.set(\n        placeholder,\n        `<code class=\"inline-code\">${content}</code>`,\n      );\n      htmlCounter++;\n      return placeholder;\n    });\n    processed = processed.replace(/\\\\underline\\{([^}]*)\\}/g, (_, content) => {\n      const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;\n      htmlSnippets.set(placeholder, `<u>${content}</u>`);\n      htmlCounter++;\n      return placeholder;\n    });\n\n    // Handle LaTeX line breaks and spacing\n    processed = processed.replace(/\\\\\\\\(?:\\s*\\n)?/g, \"\\n\"); // \\\\ -> newline\n    processed = processed.replace(/\\\\newline/g, \"\\n\");\n    processed = processed.replace(/\\\\par\\b/g, \"\\n\\n\");\n    processed = processed.replace(/\\\\quad/g, \" \");\n    processed = processed.replace(/\\\\qquad/g, \"  \");\n    processed = processed.replace(/~~/g, \" \"); // non-breaking space\n\n    // Remove other common LaTeX commands that don't render\n    processed = processed.replace(/\\\\centering/g, \"\");\n    processed = processed.replace(/\\\\noindent/g, \"\");\n    processed = processed.replace(/\\\\hfill/g, \"\");\n    processed = processed.replace(/\\\\vspace\\{[^}]*\\}/g, \"\");\n    processed = processed.replace(/\\\\hspace\\{[^}]*\\}/g, \" \");\n\n    // Convert \\(...\\) to placeholder (display: false)\n    processed = processed.replace(/\\\\\\(([\\s\\S]+?)\\\\\\)/g, (_, content) => {\n      const placeholder = `${MATH_PLACEHOLDER_PREFIX}INLINE${mathCounter}END`;\n      mathExpressions.set(placeholder, { content, displayMode: false });\n      mathCounter++;\n      return placeholder;\n    });\n\n    // Convert \\[...\\] to placeholder (display: true)\n    processed = processed.replace(/\\\\\\[([\\s\\S]*?)\\\\\\]/g, (_, content) => {\n      const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;\n      mathExpressions.set(placeholder, { content, displayMode: true });\n      mathCounter++;\n      return placeholder;\n    });\n\n    // Extract display math ($$...$$) BEFORE markdown processing\n    processed = processed.replace(/\\$\\$([\\s\\S]*?)\\$\\$/g, (_, content) => {\n      const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;\n      mathExpressions.set(placeholder, {\n        content: content.trim(),\n        displayMode: true,\n      });\n      mathCounter++;\n      return placeholder;\n    });\n\n    // Extract inline math ($...$) BEFORE markdown processing\n    // Allow single-line only, skip currency patterns like $5 or $50\n    processed = processed.replace(/\\$([^\\$\\n]+?)\\$/g, (match, content) => {\n      if (/^\\d/.test(content.trim())) {\n        return match; // Keep as-is for currency\n      }\n      const placeholder = `${MATH_PLACEHOLDER_PREFIX}INLINE${mathCounter}END`;\n      mathExpressions.set(placeholder, {\n        content: content.trim(),\n        displayMode: false,\n      });\n      mathCounter++;\n      return placeholder;\n    });\n\n    // Restore escaped dollar signs\n    processed = processed.replace(/ESCAPEDDOLLARPLACEHOLDER/g, \"$\");\n\n    // Restore code blocks\n    processed = processed.replace(\n      new RegExp(`${CODE_PLACEHOLDER_PREFIX}(\\\\d+)END`, \"g\"),\n      (_, index) => codeBlocks[parseInt(index)],\n    );\n\n    // Clean up any remaining stray backslashes from unrecognized commands\n    processed = processed.replace(/\\\\(?=[a-zA-Z])/g, \"\"); // Remove \\ before letters (unrecognized commands)\n\n    return processed;\n  }\n\n  /**\n   * Render math 
expressions with KaTeX and restore HTML placeholders\n   */\n  function renderMath(html: string): string {\n    // Replace all math placeholders with rendered KaTeX\n    for (const [placeholder, { content, displayMode }] of mathExpressions) {\n      const escapedPlaceholder = placeholder.replace(\n        /[.*+?^${}()|[\\]\\\\]/g,\n        \"\\\\$&\",\n      );\n      const regex = new RegExp(escapedPlaceholder, \"g\");\n\n      html = html.replace(regex, () => {\n        try {\n          const rendered = katex.renderToString(content, {\n            displayMode,\n            throwOnError: false,\n            output: \"html\",\n          });\n\n          if (displayMode) {\n            return `\n\t\t\t\t\t\t\t<div class=\"math-display-wrapper\">\n\t\t\t\t\t\t\t\t<div class=\"math-display-header\">\n\t\t\t\t\t\t\t\t\t<span class=\"math-label\">LaTeX</span>\n\t\t\t\t\t\t\t\t\t<button type=\"button\" class=\"copy-math-btn\" data-math-source=\"${encodeURIComponent(content)}\" title=\"Copy LaTeX source\">\n\t\t\t\t\t\t\t\t\t\t<svg width=\"14\" height=\"14\" viewBox=\"0 0 24 24\" fill=\"none\" stroke=\"currentColor\" stroke-width=\"2\" stroke-linecap=\"round\" stroke-linejoin=\"round\">\n\t\t\t\t\t\t\t\t\t\t\t<rect width=\"14\" height=\"14\" x=\"8\" y=\"8\" rx=\"2\" ry=\"2\"/>\n\t\t\t\t\t\t\t\t\t\t\t<path d=\"M4 16c-1.1 0-2-.9-2-2V4c0-1.1.9-2 2-2h10c1.1 0 2 .9 2 2\"/>\n\t\t\t\t\t\t\t\t\t\t</svg>\n\t\t\t\t\t\t\t\t\t</button>\n\t\t\t\t\t\t\t\t</div>\n\t\t\t\t\t\t\t\t<div class=\"math-display-content\">\n\t\t\t\t\t\t\t\t\t${rendered}\n\t\t\t\t\t\t\t\t</div>\n\t\t\t\t\t\t\t</div>\n\t\t\t\t\t\t`;\n          } else {\n            return `<span class=\"math-inline\">${rendered}</span>`;\n          }\n        } catch {\n          const display = displayMode ? `$$${content}$$` : `$${content}$`;\n          return `<span class=\"math-error\"><span class=\"math-error-icon\">⚠</span> ${display}</span>`;\n        }\n      });\n    }\n\n    // Restore HTML placeholders (for \\textbf, \\emph, etc.)\n    for (const [placeholder, htmlContent] of htmlSnippets) {\n      const escapedPlaceholder = placeholder.replace(\n        /[.*+?^${}()|[\\]\\\\]/g,\n        \"\\\\$&\",\n      );\n      const regex = new RegExp(escapedPlaceholder, \"g\");\n      html = html.replace(regex, htmlContent);\n    }\n\n    return html;\n  }\n\n  function processMarkdown(text: string): string {\n    try {\n      // Preprocess LaTeX notation\n      const preprocessed = preprocessLaTeX(text);\n      // Parse markdown\n      let html = marked.parse(preprocessed) as string;\n      // Render math expressions\n      html = renderMath(html);\n      return html;\n    } catch (error) {\n      console.error(\"Markdown processing error:\", error);\n      return text.replace(/\\n/g, \"<br>\");\n    }\n  }\n\n  async function handleCopyClick(event: Event) {\n    const target = event.currentTarget as HTMLButtonElement;\n    const encodedCode = target.getAttribute(\"data-code\");\n    if (!encodedCode) return;\n\n    const code = decodeURIComponent(encodedCode);\n\n    try {\n      await navigator.clipboard.writeText(code);\n      // Show copied feedback\n      const originalHtml = target.innerHTML;\n      target.innerHTML = `\n\t\t\t\t<svg width=\"16\" height=\"16\" viewBox=\"0 0 24 24\" fill=\"none\" stroke=\"currentColor\" stroke-width=\"2\" stroke-linecap=\"round\" stroke-linejoin=\"round\">\n\t\t\t\t\t<path d=\"M20 6L9 17l-5-5\"/>\n\t\t\t\t</svg>\n\t\t\t`;\n      target.classList.add(\"copied\");\n      setTimeout(() => {\n        target.innerHTML 
= originalHtml;\n        target.classList.remove(\"copied\");\n      }, 2000);\n    } catch (error) {\n      console.error(\"Failed to copy:\", error);\n    }\n  }\n\n  async function handleMathCopyClick(event: Event) {\n    const target = event.currentTarget as HTMLButtonElement;\n    const encodedSource = target.getAttribute(\"data-math-source\");\n    if (!encodedSource) return;\n\n    const source = decodeURIComponent(encodedSource);\n\n    try {\n      await navigator.clipboard.writeText(source);\n      // Show copied feedback\n      const originalHtml = target.innerHTML;\n      target.innerHTML = `\n\t\t\t\t<svg width=\"14\" height=\"14\" viewBox=\"0 0 24 24\" fill=\"none\" stroke=\"currentColor\" stroke-width=\"2\" stroke-linecap=\"round\" stroke-linejoin=\"round\">\n\t\t\t\t\t<path d=\"M20 6L9 17l-5-5\"/>\n\t\t\t\t</svg>\n\t\t\t`;\n      target.classList.add(\"copied\");\n      setTimeout(() => {\n        target.innerHTML = originalHtml;\n        target.classList.remove(\"copied\");\n      }, 2000);\n    } catch (error) {\n      console.error(\"Failed to copy math:\", error);\n    }\n  }\n\n  function setupCopyButtons() {\n    if (!containerRef || !browser) return;\n\n    const codeButtons =\n      containerRef.querySelectorAll<HTMLButtonElement>(\".copy-code-btn\");\n    for (const button of codeButtons) {\n      if (button.dataset.listenerBound !== \"true\") {\n        button.dataset.listenerBound = \"true\";\n        button.addEventListener(\"click\", handleCopyClick);\n      }\n    }\n\n    const mathButtons =\n      containerRef.querySelectorAll<HTMLButtonElement>(\".copy-math-btn\");\n    for (const button of mathButtons) {\n      if (button.dataset.listenerBound !== \"true\") {\n        button.dataset.listenerBound = \"true\";\n        button.addEventListener(\"click\", handleMathCopyClick);\n      }\n    }\n  }\n\n  $effect(() => {\n    if (content) {\n      processedHtml = processMarkdown(content);\n    } else {\n      processedHtml = \"\";\n    }\n  });\n\n  $effect(() => {\n    if (!containerRef || !browser) return;\n\n    function handleDelegatedClick(event: MouseEvent) {\n      const codeBtn = (event.target as HTMLElement).closest(\n        \".copy-code-btn\",\n      ) as HTMLButtonElement | null;\n      if (codeBtn) {\n        handleCopyClick({ currentTarget: codeBtn } as unknown as Event);\n        return;\n      }\n      const mathBtn = (event.target as HTMLElement).closest(\n        \".copy-math-btn\",\n      ) as HTMLButtonElement | null;\n      if (mathBtn) {\n        handleMathCopyClick({ currentTarget: mathBtn } as unknown as Event);\n        return;\n      }\n    }\n\n    containerRef.addEventListener(\"click\", handleDelegatedClick);\n    return () => {\n      containerRef?.removeEventListener(\"click\", handleDelegatedClick);\n    };\n  });\n</script>\n\n<div bind:this={containerRef} class=\"markdown-content {className}\">\n  {@html processedHtml}\n</div>\n\n<style>\n  .markdown-content {\n    line-height: 1.6;\n  }\n\n  /* Paragraphs */\n  .markdown-content :global(p) {\n    margin-bottom: 1rem;\n  }\n\n  .markdown-content :global(p:last-child) {\n    margin-bottom: 0;\n  }\n\n  /* Headers */\n  .markdown-content :global(h1) {\n    font-size: 1.5rem;\n    font-weight: 700;\n    margin: 1.5rem 0 0.75rem 0;\n    color: var(--exo-yellow, #ffd700);\n  }\n\n  .markdown-content :global(h2) {\n    font-size: 1.25rem;\n    font-weight: 600;\n    margin: 1.25rem 0 0.5rem 0;\n    color: var(--exo-yellow, #ffd700);\n  }\n\n  .markdown-content :global(h3) {\n    
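/* h3 and below inherit the body text color; only h1/h2 use the accent yellow */\n    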
font-size: 1.125rem;\n    font-weight: 600;\n    margin: 1rem 0 0.5rem 0;\n  }\n\n  .markdown-content :global(h4),\n  .markdown-content :global(h5),\n  .markdown-content :global(h6) {\n    font-size: 1rem;\n    font-weight: 600;\n    margin: 0.75rem 0 0.25rem 0;\n  }\n\n  /* Bold and italic */\n  .markdown-content :global(strong) {\n    font-weight: 600;\n  }\n\n  .markdown-content :global(em) {\n    font-style: italic;\n  }\n\n  /* Inline code */\n  .markdown-content :global(.inline-code) {\n    background: rgba(255, 215, 0, 0.1);\n    color: var(--exo-yellow, #ffd700);\n    padding: 0.125rem 0.375rem;\n    border-radius: 0.25rem;\n    font-family:\n      ui-monospace, SFMono-Regular, \"SF Mono\", Monaco, Consolas, monospace;\n    font-size: 0.875em;\n  }\n\n  /* Links */\n  .markdown-content :global(a) {\n    color: var(--exo-yellow, #ffd700);\n    text-decoration: underline;\n    text-underline-offset: 2px;\n  }\n\n  .markdown-content :global(a:hover) {\n    opacity: 0.8;\n  }\n\n  /* Lists */\n  .markdown-content :global(ul) {\n    list-style-type: disc;\n    margin-left: 1.5rem;\n    margin-bottom: 1rem;\n  }\n\n  .markdown-content :global(ol) {\n    list-style-type: decimal;\n    margin-left: 1.5rem;\n    margin-bottom: 1rem;\n  }\n\n  .markdown-content :global(li) {\n    margin-bottom: 0.25rem;\n  }\n\n  .markdown-content :global(li::marker) {\n    color: var(--exo-light-gray, #9ca3af);\n  }\n\n  /* Blockquotes */\n  .markdown-content :global(blockquote) {\n    border-left: 3px solid var(--exo-yellow, #ffd700);\n    padding: 0.5rem 1rem;\n    margin: 1rem 0;\n    background: rgba(255, 215, 0, 0.05);\n    border-radius: 0 0.25rem 0.25rem 0;\n  }\n\n  /* Tables */\n  .markdown-content :global(table) {\n    width: 100%;\n    margin: 1rem 0;\n    border-collapse: collapse;\n    font-size: 0.875rem;\n  }\n\n  .markdown-content :global(th) {\n    background: rgba(255, 215, 0, 0.1);\n    border: 1px solid rgba(255, 215, 0, 0.2);\n    padding: 0.5rem;\n    text-align: left;\n    font-weight: 600;\n  }\n\n  .markdown-content :global(td) {\n    border: 1px solid rgba(255, 255, 255, 0.1);\n    padding: 0.5rem;\n  }\n\n  /* Horizontal rule */\n  .markdown-content :global(hr) {\n    border: none;\n    border-top: 1px solid rgba(255, 255, 255, 0.1);\n    margin: 1.5rem 0;\n  }\n\n  /* Code block wrapper */\n  .markdown-content :global(.code-block-wrapper) {\n    margin: 1rem 0;\n    border-radius: 0.5rem;\n    overflow: hidden;\n    border: 1px solid rgba(255, 215, 0, 0.2);\n    background: rgba(0, 0, 0, 0.4);\n  }\n\n  .markdown-content :global(.code-block-header) {\n    display: flex;\n    justify-content: space-between;\n    align-items: center;\n    padding: 0.5rem 0.75rem;\n    background: rgba(255, 215, 0, 0.05);\n    border-bottom: 1px solid rgba(255, 215, 0, 0.1);\n  }\n\n  .markdown-content :global(.code-language) {\n    color: var(--exo-yellow, #ffd700);\n    font-size: 0.7rem;\n    font-weight: 500;\n    text-transform: uppercase;\n    letter-spacing: 0.1em;\n    font-family:\n      ui-monospace, SFMono-Regular, \"SF Mono\", Monaco, Consolas, monospace;\n  }\n\n  .markdown-content :global(.copy-code-btn) {\n    display: flex;\n    align-items: center;\n    justify-content: center;\n    padding: 0.25rem;\n    background: transparent;\n    border: none;\n    color: var(--exo-light-gray, #9ca3af);\n    cursor: pointer;\n    transition: color 0.2s;\n    border-radius: 0.25rem;\n  }\n\n  .markdown-content :global(.copy-code-btn:hover) {\n    color: var(--exo-yellow, #ffd700);\n  }\n\n  
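/* Success state applied by the copy handler for two seconds after a clipboard write */\n  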
.markdown-content :global(.copy-code-btn.copied) {\n    color: #22c55e;\n  }\n\n  .markdown-content :global(.code-block-wrapper pre) {\n    margin: 0;\n    padding: 1rem;\n    overflow-x: auto;\n    background: transparent;\n  }\n\n  .markdown-content :global(.code-block-wrapper code) {\n    font-family:\n      ui-monospace, SFMono-Regular, \"SF Mono\", Monaco, Consolas, monospace;\n    font-size: 0.8125rem;\n    line-height: 1.5;\n    background: transparent;\n  }\n\n  /* Syntax highlighting - dark theme matching EXO style */\n  .markdown-content :global(.hljs) {\n    color: #e5e7eb;\n  }\n\n  .markdown-content :global(.hljs-keyword),\n  .markdown-content :global(.hljs-selector-tag),\n  .markdown-content :global(.hljs-literal),\n  .markdown-content :global(.hljs-section),\n  .markdown-content :global(.hljs-link) {\n    color: #c084fc;\n  }\n\n  .markdown-content :global(.hljs-string),\n  .markdown-content :global(.hljs-title),\n  .markdown-content :global(.hljs-name),\n  .markdown-content :global(.hljs-type),\n  .markdown-content :global(.hljs-attribute),\n  .markdown-content :global(.hljs-symbol),\n  .markdown-content :global(.hljs-bullet),\n  .markdown-content :global(.hljs-addition),\n  .markdown-content :global(.hljs-variable),\n  .markdown-content :global(.hljs-template-tag),\n  .markdown-content :global(.hljs-template-variable) {\n    color: #fbbf24;\n  }\n\n  .markdown-content :global(.hljs-comment),\n  .markdown-content :global(.hljs-quote),\n  .markdown-content :global(.hljs-deletion),\n  .markdown-content :global(.hljs-meta) {\n    color: #6b7280;\n  }\n\n  .markdown-content :global(.hljs-number),\n  .markdown-content :global(.hljs-regexp),\n  .markdown-content :global(.hljs-literal),\n  .markdown-content :global(.hljs-built_in) {\n    color: #34d399;\n  }\n\n  .markdown-content :global(.hljs-function),\n  .markdown-content :global(.hljs-class .hljs-title) {\n    color: #60a5fa;\n  }\n\n  /* KaTeX math styling - Base */\n  .markdown-content :global(.katex) {\n    font-size: 1.1em;\n    color: oklch(0.9 0 0);\n  }\n\n  /* Display math container wrapper */\n  .markdown-content :global(.math-display-wrapper) {\n    margin: 1rem 0;\n    border-radius: 0.5rem;\n    overflow: hidden;\n    border: 1px solid rgba(255, 215, 0, 0.15);\n    background: rgba(0, 0, 0, 0.3);\n    transition:\n      border-color 0.2s ease,\n      box-shadow 0.2s ease;\n  }\n\n  .markdown-content :global(.math-display-wrapper:hover) {\n    border-color: rgba(255, 215, 0, 0.25);\n    box-shadow: 0 0 12px rgba(255, 215, 0, 0.08);\n  }\n\n  /* Display math header - hidden by default, slides in on hover */\n  .markdown-content :global(.math-display-header) {\n    display: flex;\n    justify-content: space-between;\n    align-items: center;\n    padding: 0.375rem 0.75rem;\n    background: rgba(255, 215, 0, 0.03);\n    border-bottom: 1px solid rgba(255, 215, 0, 0.08);\n    opacity: 0;\n    max-height: 0;\n    padding-top: 0;\n    padding-bottom: 0;\n    overflow: hidden;\n    transition:\n      opacity 0.2s ease,\n      max-height 0.2s ease,\n      padding 0.2s ease;\n  }\n\n  .markdown-content :global(.math-display-wrapper:hover .math-display-header) {\n    opacity: 1;\n    max-height: 2.5rem;\n    padding: 0.375rem 0.75rem;\n  }\n\n  .markdown-content :global(.math-label) {\n    color: rgba(255, 215, 0, 0.7);\n    font-size: 0.65rem;\n    font-weight: 500;\n    text-transform: uppercase;\n    letter-spacing: 0.1em;\n    font-family:\n      ui-monospace, SFMono-Regular, \"SF Mono\", Monaco, Consolas, monospace;\n  
}\n\n  .markdown-content :global(.copy-math-btn) {\n    display: flex;\n    align-items: center;\n    justify-content: center;\n    padding: 0.25rem;\n    background: transparent;\n    border: none;\n    color: var(--exo-light-gray, #9ca3af);\n    cursor: pointer;\n    border-radius: 0.25rem;\n    opacity: 0;\n    transition:\n      color 0.2s,\n      opacity 0.15s ease;\n  }\n\n  .markdown-content :global(.math-display-wrapper:hover .copy-math-btn) {\n    opacity: 1;\n  }\n\n  .markdown-content :global(.copy-math-btn:hover) {\n    color: var(--exo-yellow, #ffd700);\n  }\n\n  .markdown-content :global(.copy-math-btn.copied) {\n    color: #22c55e;\n  }\n\n  /* Display math content area */\n  .markdown-content :global(.math-display-content) {\n    padding: 1rem 1.25rem;\n    overflow-x: auto;\n    overflow-y: hidden;\n  }\n\n  /* Custom scrollbar for math overflow */\n  .markdown-content :global(.math-display-content::-webkit-scrollbar) {\n    height: 6px;\n  }\n\n  .markdown-content :global(.math-display-content::-webkit-scrollbar-track) {\n    background: rgba(255, 255, 255, 0.05);\n    border-radius: 3px;\n  }\n\n  .markdown-content :global(.math-display-content::-webkit-scrollbar-thumb) {\n    background: rgba(255, 215, 0, 0.2);\n    border-radius: 3px;\n  }\n\n  .markdown-content\n    :global(.math-display-content::-webkit-scrollbar-thumb:hover) {\n    background: rgba(255, 215, 0, 0.35);\n  }\n\n  .markdown-content :global(.math-display-content .katex-display) {\n    margin: 0;\n    padding: 0;\n  }\n\n  .markdown-content :global(.math-display-content .katex-display > .katex) {\n    text-align: center;\n  }\n\n  /* Inline math wrapper */\n  .markdown-content :global(.math-inline) {\n    display: inline;\n    padding: 0 0.125rem;\n    border-radius: 0.25rem;\n    transition: background-color 0.15s ease;\n  }\n\n  .markdown-content :global(.math-inline:hover) {\n    background: rgba(255, 215, 0, 0.05);\n  }\n\n  /* Dark theme KaTeX overrides */\n  .markdown-content :global(.katex .mord),\n  .markdown-content :global(.katex .minner),\n  .markdown-content :global(.katex .mop),\n  .markdown-content :global(.katex .mbin),\n  .markdown-content :global(.katex .mrel),\n  .markdown-content :global(.katex .mpunct) {\n    color: oklch(0.9 0 0);\n  }\n\n  /* Fraction lines and rules */\n  .markdown-content :global(.katex .frac-line),\n  .markdown-content :global(.katex .overline-line),\n  .markdown-content :global(.katex .underline-line),\n  .markdown-content :global(.katex .hline),\n  .markdown-content :global(.katex .rule) {\n    border-color: oklch(0.85 0 0) !important;\n    background: oklch(0.85 0 0);\n  }\n\n  /* Square roots and SVG elements */\n  .markdown-content :global(.katex .sqrt-line) {\n    border-color: oklch(0.85 0 0) !important;\n  }\n\n  .markdown-content :global(.katex svg) {\n    fill: oklch(0.85 0 0);\n    stroke: oklch(0.85 0 0);\n  }\n\n  .markdown-content :global(.katex svg path) {\n    stroke: oklch(0.85 0 0);\n  }\n\n  /* Delimiters (parentheses, brackets, braces) */\n  .markdown-content :global(.katex .delimsizing),\n  .markdown-content :global(.katex .delim-size1),\n  .markdown-content :global(.katex .delim-size2),\n  .markdown-content :global(.katex .delim-size3),\n  .markdown-content :global(.katex .delim-size4),\n  .markdown-content :global(.katex .mopen),\n  .markdown-content :global(.katex .mclose) {\n    color: oklch(0.75 0 0);\n  }\n\n  /* Math error styling */\n  .markdown-content :global(.math-error) {\n    display: inline-flex;\n   
 align-items: center;\n    gap: 0.375rem;\n    color: #f87171;\n    font-family:\n      ui-monospace, SFMono-Regular, \"SF Mono\", Monaco, Consolas, monospace;\n    font-size: 0.875em;\n    background: rgba(248, 113, 113, 0.1);\n    padding: 0.25rem 0.5rem;\n    border-radius: 0.25rem;\n    border: 1px solid rgba(248, 113, 113, 0.2);\n  }\n\n  .markdown-content :global(.math-error-icon) {\n    font-size: 0.875em;\n    opacity: 0.9;\n  }\n\n  /* LaTeX proof environment */\n  .markdown-content :global(.latex-proof) {\n    margin: 1rem 0;\n    padding: 1rem 1.25rem;\n    background: rgba(255, 255, 255, 0.02);\n    border-left: 3px solid rgba(255, 215, 0, 0.4);\n    border-radius: 0 0.375rem 0.375rem 0;\n  }\n\n  .markdown-content :global(.latex-proof-header) {\n    font-weight: 600;\n    font-style: italic;\n    color: oklch(0.85 0 0);\n    margin-bottom: 0.5rem;\n  }\n\n  .markdown-content :global(.latex-proof-header::after) {\n    content: \".\";\n  }\n\n  .markdown-content :global(.latex-proof-content) {\n    color: oklch(0.9 0 0);\n  }\n\n  .markdown-content :global(.latex-proof-content p:last-child) {\n    margin-bottom: 0;\n  }\n\n  /* QED symbol at end of proof */\n  .markdown-content :global(.latex-proof-content::after) {\n    content: \"∎\";\n    display: block;\n    text-align: right;\n    color: oklch(0.7 0 0);\n    margin-top: 0.5rem;\n  }\n\n  /* LaTeX theorem-like environments */\n  .markdown-content :global(.latex-theorem) {\n    margin: 1rem 0;\n    padding: 1rem 1.25rem;\n    background: rgba(255, 215, 0, 0.03);\n    border: 1px solid rgba(255, 215, 0, 0.15);\n    border-radius: 0.375rem;\n  }\n\n  .markdown-content :global(.latex-theorem-header) {\n    font-weight: 700;\n    color: var(--exo-yellow, #ffd700);\n    margin-bottom: 0.5rem;\n  }\n\n  .markdown-content :global(.latex-theorem-header::after) {\n    content: \".\";\n  }\n\n  .markdown-content :global(.latex-theorem-content) {\n    color: oklch(0.9 0 0);\n    font-style: italic;\n  }\n\n  .markdown-content :global(.latex-theorem-content p:last-child) {\n    margin-bottom: 0;\n  }\n\n  /* LaTeX diagram/figure placeholder */\n  .markdown-content :global(.latex-diagram-placeholder) {\n    display: flex;\n    align-items: center;\n    justify-content: center;\n    gap: 0.5rem;\n    margin: 1rem 0;\n    padding: 1.5rem 2rem;\n    background: rgba(255, 255, 255, 0.02);\n    border: 1px dashed rgba(255, 215, 0, 0.25);\n    border-radius: 0.5rem;\n    color: rgba(255, 215, 0, 0.6);\n    font-size: 0.875rem;\n  }\n\n  .markdown-content :global(.latex-diagram-icon) {\n    font-size: 1.25rem;\n    opacity: 0.8;\n  }\n\n  .markdown-content :global(.latex-diagram-text) {\n    font-family:\n      ui-monospace, SFMono-Regular, \"SF Mono\", Monaco, Consolas, monospace;\n    font-size: 0.75rem;\n    text-transform: uppercase;\n    letter-spacing: 0.05em;\n  }\n</style>\n"
  },
  {
    "path": "dashboard/src/lib/components/ModelCard.svelte",
    "content": "<script lang=\"ts\">\n  import type {\n    DownloadProgress,\n    NodeInfo,\n    PlacementPreview,\n    TopologyEdge,\n  } from \"$lib/stores/app.svelte\";\n  import { debugMode, topologyData } from \"$lib/stores/app.svelte\";\n\n  interface Props {\n    model: { id: string; name?: string; storage_size_megabytes?: number };\n    isLaunching?: boolean;\n    downloadStatus?: {\n      isDownloading: boolean;\n      progress: DownloadProgress | null;\n      perNode?: Array<{\n        nodeId: string;\n        nodeName: string;\n        status: \"completed\" | \"partial\" | \"pending\" | \"downloading\";\n        percentage: number;\n        progress: DownloadProgress | null;\n      }>;\n    } | null;\n    nodes?: Record<string, NodeInfo>;\n    sharding?: \"Pipeline\" | \"Tensor\";\n    runtime?: \"MlxRing\" | \"MlxJaccl\";\n    onLaunch?: () => void;\n    tags?: string[];\n    apiPreview?: PlacementPreview | null;\n    modelIdOverride?: string | null;\n  }\n\n  let {\n    model,\n    isLaunching = false,\n    downloadStatus = null,\n    nodes = {},\n    sharding = \"Pipeline\",\n    runtime = \"MlxRing\",\n    onLaunch,\n    tags = [],\n    apiPreview = null,\n    modelIdOverride = null,\n  }: Props = $props();\n\n  // Estimate memory requirements from model name\n  // Uses regex with word boundaries to avoid false matches like '4bit' matching '4b'\n  function estimateMemoryGB(modelId: string, modelName?: string): number {\n    // Check both ID and name for quantization info\n    const combined = `${modelId} ${modelName || \"\"}`.toLowerCase();\n\n    // Detect quantization level - affects memory by roughly 2x between levels\n    const is4bit =\n      combined.includes(\"4bit\") ||\n      combined.includes(\"4-bit\") ||\n      combined.includes(\":4bit\");\n    const is8bit =\n      combined.includes(\"8bit\") ||\n      combined.includes(\"8-bit\") ||\n      combined.includes(\":8bit\");\n    // 4-bit = 0.5 bytes/param, 8-bit = 1 byte/param, fp16 = 2 bytes/param\n    const quantMultiplier = is4bit ? 0.5 : is8bit ? 
1 : 2;\n    const id = modelId.toLowerCase();\n\n    // Known large models that don't follow the standard naming pattern\n    // DeepSeek V3 has 685B parameters\n    if (id.includes(\"deepseek-v3\")) {\n      return Math.round(685 * quantMultiplier);\n    }\n    // DeepSeek V2 has 236B parameters\n    if (id.includes(\"deepseek-v2\")) {\n      return Math.round(236 * quantMultiplier);\n    }\n    // Llama 4 Scout/Maverick are large models\n    if (id.includes(\"llama-4\")) {\n      return Math.round(400 * quantMultiplier);\n    }\n\n    // Match parameter counts with word boundaries (e.g., \"70b\" but not \"4bit\")\n    const paramMatch = id.match(/(\\d+(?:\\.\\d+)?)\\s*b(?![a-z])/i);\n    if (paramMatch) {\n      const params = parseFloat(paramMatch[1]);\n      return Math.max(4, Math.round(params * quantMultiplier));\n    }\n\n    // Fallback patterns for explicit size markers (assume fp16 baseline, adjust for quant)\n    if (id.includes(\"405b\") || id.includes(\"400b\"))\n      return Math.round(405 * quantMultiplier);\n    if (id.includes(\"180b\")) return Math.round(180 * quantMultiplier);\n    if (id.includes(\"141b\") || id.includes(\"140b\"))\n      return Math.round(140 * quantMultiplier);\n    if (id.includes(\"123b\") || id.includes(\"120b\"))\n      return Math.round(123 * quantMultiplier);\n    if (id.includes(\"72b\") || id.includes(\"70b\"))\n      return Math.round(70 * quantMultiplier);\n    if (id.includes(\"67b\") || id.includes(\"65b\"))\n      return Math.round(65 * quantMultiplier);\n    if (\n      id.includes(\"35b\") ||\n      id.includes(\"34b\") ||\n      id.includes(\"32b\") ||\n      id.includes(\"30b\")\n    )\n      return Math.round(32 * quantMultiplier);\n    if (id.includes(\"27b\") || id.includes(\"26b\") || id.includes(\"22b\"))\n      return Math.round(24 * quantMultiplier);\n    if (id.includes(\"14b\") || id.includes(\"13b\") || id.includes(\"15b\"))\n      return Math.round(14 * quantMultiplier);\n    if (id.includes(\"8b\") || id.includes(\"9b\") || id.includes(\"7b\"))\n      return Math.round(8 * quantMultiplier);\n    if (id.includes(\"3b\") || id.includes(\"3.8b\"))\n      return Math.round(4 * quantMultiplier);\n    if (\n      id.includes(\"2b\") ||\n      id.includes(\"1b\") ||\n      id.includes(\"1.5b\") ||\n      id.includes(\"0.5b\")\n    )\n      return Math.round(2 * quantMultiplier);\n\n    return 16; // Default fallback\n  }\n\n  function formatBytes(bytes: number, decimals = 1): string {\n    if (!bytes || bytes === 0) return \"0 B\";\n    const k = 1024;\n    const sizes = [\"B\", \"KB\", \"MB\", \"GB\", \"TB\"];\n    const i = Math.floor(Math.log(bytes) / Math.log(k));\n    return (\n      parseFloat((bytes / Math.pow(k, i)).toFixed(decimals)) + \" \" + sizes[i]\n    );\n  }\n\n  function formatSpeed(bps: number): string {\n    if (!bps || bps <= 0) return \"0 B/s\";\n    return formatBytes(bps) + \"/s\";\n  }\n\n  function formatEta(ms: number): string {\n    if (!ms || ms <= 0) return \"--\";\n    const totalSeconds = Math.round(ms / 1000);\n    const s = totalSeconds % 60;\n    const m = Math.floor(totalSeconds / 60) % 60;\n    const h = Math.floor(totalSeconds / 3600);\n    if (h > 0) return `${h}h ${m}m`;\n    if (m > 0) return `${m}m ${s}s`;\n    return `${s}s`;\n  }\n\n  const perNode = $derived(downloadStatus?.perNode ?? 
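[]);\n\n  // Illustrative sanity checks for the helpers above (hypothetical inputs):\n  //   estimateMemoryGB(\"llama-3.1-8b-instruct\") === 16  // 8B params at fp16 (2 bytes/param)\n  //   formatBytes(1.5 * 1024 * 1024) === \"1.5 MB\"\n  //   formatEta(95000) === \"1m 35s\"\n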
  // Node IDs whose per-node download details are currently expanded\n  let expandedNodes = $state(new Set<string>());\n\n  function toggleNodeDetails(nodeId: string): void {\n    const next = new Set(expandedNodes);\n    if (next.has(nodeId)) {\n      next.delete(nodeId);\n    } else {\n      next.add(nodeId);\n    }\n    expandedNodes = next;\n  }\n\n  // Use actual storage_size_megabytes from API if available, otherwise fall back to estimate\n  const estimatedMemory = $derived(\n    model.storage_size_megabytes\n      ? Math.round(model.storage_size_megabytes / 1024)\n      : estimateMemoryGB(model.id, model.name),\n  );\n\n  function getDeviceType(\n    name: string,\n  ): \"macbook\" | \"studio\" | \"mini\" | \"unknown\" {\n    const lower = name.toLowerCase();\n    if (lower.includes(\"macbook\")) return \"macbook\";\n    if (lower.includes(\"studio\")) return \"studio\";\n    if (lower.includes(\"mini\")) return \"mini\";\n    return \"unknown\";\n  }\n\n  const clampPercent = (value: number): number =>\n    Math.min(100, Math.max(0, value));\n  const huggingFaceModelId = $derived(modelIdOverride ?? model.id);\n\n  // Get node list in the same order as the topology graph (insertion order of\n  // topology nodes), while still ensuring preview nodes render even if the\n  // topology payload is missing them. Topology order is preserved exactly so\n  // that the mini preview matches the main TopologyGraph layout.\n  const nodeList = $derived(() => {\n    const nodesFromTopology = Object.keys(nodes).map((id) => {\n      const info = nodes[id];\n      const totalBytes =\n        info.macmon_info?.memory?.ram_total ?? info.system_info?.memory ?? 0;\n      const usedBytes = info.macmon_info?.memory?.ram_usage ?? 0;\n      const availableBytes = Math.max(totalBytes - usedBytes, 0);\n      const totalGB = totalBytes / (1024 * 1024 * 1024);\n      const availableGB = availableBytes / (1024 * 1024 * 1024);\n      const usedGB = Math.max(totalGB - availableGB, 0);\n      const deviceName = info.system_info?.model_id ?? \"Unknown\";\n      const deviceType = getDeviceType(deviceName);\n\n      return {\n        id,\n        totalGB,\n        availableGB,\n        usedGB,\n        deviceName,\n        deviceType,\n        usedBytes,\n        totalBytes,\n      };\n    });\n\n    const previewEntries = apiPreview?.memory_delta_by_node ?? null;\n    const previewIds = previewEntries ? Object.keys(previewEntries) : [];\n\n    if (previewIds.length === 0) return nodesFromTopology;\n\n    // Append any preview-only nodes (not in topology) at the end\n    const topologyIds = new Set(nodesFromTopology.map((n) => n.id));\n    const extraPreviewNodes = previewIds\n      .filter((id) => !topologyIds.has(id))\n      .map((id) => {\n        const deltaBytes = previewEntries?.[id] ?? 
0;\n        const deltaGB = deltaBytes / (1024 * 1024 * 1024);\n        const totalGB = Math.max(deltaGB * 1.2, 1);\n        const usedGB = Math.max(totalGB - deltaGB, 0);\n\n        return {\n          id,\n          totalGB,\n          availableGB: Math.max(totalGB - usedGB, 0),\n          usedGB,\n          deviceName: \"Unknown\",\n          deviceType: \"unknown\" as const,\n          usedBytes: usedGB * 1024 * 1024 * 1024,\n          totalBytes: totalGB * 1024 * 1024 * 1024,\n        };\n      });\n\n    return [...nodesFromTopology, ...extraPreviewNodes];\n  });\n\n  // Calculate placement preview with all SVG metrics pre-computed.\n  // Placement comes solely from the API preview; nothing is estimated locally.\n  const placementPreview = $derived(() => {\n    const nodeArray = nodeList();\n    if (nodeArray.length === 0)\n      return {\n        nodes: [],\n        canFit: false,\n        totalAvailable: 0,\n        topoWidth: 260,\n        topoHeight: 90,\n        error: null,\n      };\n\n    const numNodes = nodeArray.length;\n    const iconSize = numNodes === 1 ? 50 : 36;\n    const topoWidth = 260;\n    const topoHeight =\n      numNodes === 1 ? 90 : numNodes === 2 ? 140 : numNodes * 50 + 20;\n    const centerX = topoWidth / 2;\n    const centerY = topoHeight / 2;\n    const radius =\n      numNodes === 1\n        ? 0\n        : numNodes === 2\n          ? 45\n          : Math.min(topoWidth, topoHeight) * 0.32;\n\n    // Only use API preview data - no local estimation\n    const hasApiPreview =\n      apiPreview !== null &&\n      apiPreview.error === null &&\n      apiPreview.memory_delta_by_node !== null;\n    const error = apiPreview?.error ?? null;\n\n    let placementNodes: Array<{\n      id: string;\n      deviceName: string;\n      deviceType: \"macbook\" | \"studio\" | \"mini\" | \"unknown\";\n      totalGB: number;\n      currentUsedGB: number;\n      modelUsageGB: number;\n      currentPercent: number;\n      newPercent: number;\n      isUsed: boolean;\n      x: number;\n      y: number;\n      iconSize: number;\n      screenHeight: number;\n      currentFillHeight: number;\n      modelFillHeight: number;\n    }> = [];\n\n    // Use API placement data directly\n    const memoryDelta = apiPreview?.memory_delta_by_node ?? {};\n    placementNodes = nodeArray.map((n, i) => {\n      const deltaBytes = memoryDelta[n.id] ?? 0;\n      const modelUsageGB = deltaBytes / (1024 * 1024 * 1024);\n      const isUsed = deltaBytes > 0;\n      const angle =\n        numNodes === 1 ? 
0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;\n      const safeTotal = Math.max(n.totalGB, 0.001);\n      const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);\n      const newPercent = clampPercent(\n        ((n.usedGB + modelUsageGB) / safeTotal) * 100,\n      );\n      const screenHeight = iconSize * 0.58;\n\n      return {\n        id: n.id,\n        deviceName: n.deviceName,\n        deviceType: n.deviceType,\n        totalGB: n.totalGB,\n        currentUsedGB: n.usedGB,\n        modelUsageGB,\n        currentPercent,\n        newPercent,\n        isUsed,\n        x: centerX + Math.cos(angle) * radius,\n        y: centerY + Math.sin(angle) * radius,\n        iconSize,\n        screenHeight,\n        currentFillHeight: screenHeight * (currentPercent / 100),\n        modelFillHeight: screenHeight * ((newPercent - currentPercent) / 100),\n      };\n    });\n\n    const totalAvailable = nodeArray.reduce((sum, n) => sum + n.availableGB, 0);\n    return {\n      nodes: placementNodes,\n      canFit: hasApiPreview,\n      totalAvailable,\n      topoWidth,\n      topoHeight,\n      error,\n    };\n  });\n\n  const canFit = $derived(\n    apiPreview ? apiPreview.error === null : placementPreview().canFit,\n  );\n  const placementError = $derived(apiPreview?.error ?? null);\n  const nodeCount = $derived(nodeList().length);\n  const filterId = $derived(model.id.replace(/[^a-zA-Z0-9]/g, \"\"));\n\n  // Debug mode state\n  const isDebugMode = $derived(debugMode());\n  const topology = $derived(topologyData());\n  const isRdma = $derived(runtime === \"MlxJaccl\");\n\n  // Get interface name for an IP from node data\n  function getInterfaceForIp(nodeId: string, ip?: string): string | null {\n    if (!ip || !topology?.nodes) return null;\n\n    // Strip port if present\n    const cleanIp =\n      ip.includes(\":\") && !ip.includes(\"[\") ? ip.split(\":\")[0] : ip;\n\n    // Check specified node first\n    const node = topology.nodes[nodeId];\n    if (node) {\n      const match = node.network_interfaces?.find((iface) =>\n        (iface.addresses || []).some((addr) => addr === cleanIp || addr === ip),\n      );\n      if (match?.name) return match.name;\n\n      const mapped =\n        node.ip_to_interface?.[cleanIp] || node.ip_to_interface?.[ip];\n      if (mapped) return mapped;\n    }\n\n    // Fallback: check all nodes\n    for (const [, otherNode] of Object.entries(topology.nodes)) {\n      if (!otherNode) continue;\n      const match = otherNode.network_interfaces?.find((iface) =>\n        (iface.addresses || []).some((addr) => addr === cleanIp || addr === ip),\n      );\n      if (match?.name) return match.name;\n\n      const mapped =\n        otherNode.ip_to_interface?.[cleanIp] || otherNode.ip_to_interface?.[ip];\n      if (mapped) return mapped;\n    }\n\n    return null;\n  }\n\n  // Get directional arrow based on node positions\n  function getArrow(\n    fromNode: { x: number; y: number },\n    toNode: { x: number; y: number },\n  ): string {\n    const dx = toNode.x - fromNode.x;\n    const dy = toNode.y - fromNode.y;\n    const absX = Math.abs(dx);\n    const absY = Math.abs(dy);\n\n    if (absX > absY * 2) {\n      return dx > 0 ? \"→\" : \"←\";\n    } else if (absY > absX * 2) {\n      return dy > 0 ? 
\"↓\" : \"↑\";\n    } else {\n      if (dx > 0 && dy > 0) return \"↘\";\n      if (dx > 0 && dy < 0) return \"↗\";\n      if (dx < 0 && dy > 0) return \"↙\";\n      return \"↖\";\n    }\n  }\n\n  // Get connection info for edges between two nodes\n  // Returns exactly one connection per direction (A→B and B→A), preferring non-loopback\n  function getConnectionInfo(\n    nodeId1: string,\n    nodeId2: string,\n  ): Array<{ ip: string; iface: string | null; from: string; to: string }> {\n    if (!topology?.edges) return [];\n\n    // Collect candidates for each direction\n    const aToBCandidates: Array<{ ip: string; iface: string | null }> = [];\n    const bToACandidates: Array<{ ip: string; iface: string | null }> = [];\n\n    for (const edge of topology.edges) {\n      let ip: string;\n      let iface: string | null;\n\n      if (edge.sourceRdmaIface || edge.sinkRdmaIface) {\n        ip = \"RDMA\";\n        iface = `${edge.sourceRdmaIface || \"?\"} \\u2192 ${edge.sinkRdmaIface || \"?\"}`;\n      } else {\n        ip = edge.sendBackIp || \"?\";\n        iface = edge.sendBackInterface || getInterfaceForIp(edge.source, ip);\n      }\n\n      if (edge.source === nodeId1 && edge.target === nodeId2) {\n        aToBCandidates.push({ ip, iface });\n      } else if (edge.source === nodeId2 && edge.target === nodeId1) {\n        bToACandidates.push({ ip, iface });\n      }\n    }\n\n    // Pick best (prefer non-loopback)\n    const pickBest = (\n      candidates: Array<{ ip: string; iface: string | null }>,\n    ) => {\n      if (candidates.length === 0) return null;\n      return candidates.find((c) => !c.ip.startsWith(\"127.\")) || candidates[0];\n    };\n\n    const result: Array<{\n      ip: string;\n      iface: string | null;\n      from: string;\n      to: string;\n    }> = [];\n\n    const bestAtoB = pickBest(aToBCandidates);\n    if (bestAtoB) result.push({ ...bestAtoB, from: nodeId1, to: nodeId2 });\n\n    const bestBtoA = pickBest(bToACandidates);\n    if (bestBtoA) result.push({ ...bestBtoA, from: nodeId2, to: nodeId1 });\n\n    return result;\n  }\n</script>\n\n<div class=\"relative group\">\n  <!-- Corner accents -->\n  <div\n    class=\"absolute -top-px -left-px w-2 h-2 border-l border-t {canFit\n      ? 'border-exo-yellow/30 group-hover:border-exo-yellow/60'\n      : 'border-red-500/30'} transition-colors\"\n  ></div>\n  <div\n    class=\"absolute -top-px -right-px w-2 h-2 border-r border-t {canFit\n      ? 'border-exo-yellow/30 group-hover:border-exo-yellow/60'\n      : 'border-red-500/30'} transition-colors\"\n  ></div>\n  <div\n    class=\"absolute -bottom-px -left-px w-2 h-2 border-l border-b {canFit\n      ? 'border-exo-yellow/30 group-hover:border-exo-yellow/60'\n      : 'border-red-500/30'} transition-colors\"\n  ></div>\n  <div\n    class=\"absolute -bottom-px -right-px w-2 h-2 border-r border-b {canFit\n      ? 'border-exo-yellow/30 group-hover:border-exo-yellow/60'\n      : 'border-red-500/30'} transition-colors\"\n  ></div>\n\n  <div\n    class=\"bg-exo-dark-gray/60 border {canFit\n      ? 
'border-exo-yellow/20 group-hover:border-exo-yellow/40'\n      : 'border-red-500/20'} p-3 transition-all duration-200 group-hover:shadow-[0_0_15px_rgba(255,215,0,0.1)]\"\n  >\n    <!-- Model Name & Memory Required -->\n    <div class=\"flex items-start justify-between gap-2 mb-2\">\n      <div class=\"flex-1 min-w-0\">\n        <div class=\"flex items-center gap-2\">\n          <div\n            class=\"text-exo-yellow text-xs font-mono tracking-wide truncate\"\n            title={model.name || model.id}\n          >\n            {model.name || model.id}\n          </div>\n          {#if huggingFaceModelId}\n            <a\n              class=\"shrink-0 text-white/60 hover:text-exo-yellow transition-colors\"\n              href={`https://huggingface.co/${huggingFaceModelId}`}\n              target=\"_blank\"\n              rel=\"noreferrer noopener\"\n              aria-label=\"View model on Hugging Face\"\n            >\n              <svg\n                class=\"w-3.5 h-3.5\"\n                viewBox=\"0 0 24 24\"\n                fill=\"none\"\n                stroke=\"currentColor\"\n                stroke-width=\"2\"\n                stroke-linecap=\"round\"\n                stroke-linejoin=\"round\"\n              >\n                <path d=\"M14 3h7v7\" />\n                <path d=\"M10 14l11-11\" />\n                <path\n                  d=\"M21 14v6a1 1 0 0 1-1 1h-16a1 1 0 0 1-1-1v-16a1 1 0 0 1 1-1h6\"\n                />\n              </svg>\n            </a>\n          {/if}\n          {#if tags.length > 0}\n            <div class=\"flex gap-1 flex-shrink-0\">\n              {#each tags as tag}\n                <span\n                  class=\"px-1.5 py-0.5 text-xs font-mono tracking-wider uppercase rounded {tag ===\n                  'FASTEST'\n                    ? 'bg-green-500/20 text-green-400 border border-green-500/30'\n                    : 'bg-purple-500/20 text-purple-400 border border-purple-500/30'}\"\n                >\n                  {tag}\n                </span>\n              {/each}\n            </div>\n          {/if}\n        </div>\n        {#if model.name && model.name !== model.id}\n          <div\n            class=\"text-xs text-exo-light-gray font-mono truncate mt-0.5\"\n            title={model.id}\n          >\n            {model.id}\n          </div>\n        {/if}\n      </div>\n      <div class=\"flex-shrink-0 text-right\">\n        <div\n          class=\"text-xs font-mono {canFit\n            ? 'text-exo-yellow'\n            : 'text-red-400'}\"\n        >\n          {estimatedMemory}GB\n        </div>\n      </div>\n    </div>\n\n    <!-- Configuration Badge -->\n    <div class=\"flex items-center gap-1.5 mb-2\">\n      <span\n        class=\"px-1.5 py-0.5 text-xs font-mono tracking-wider uppercase bg-exo-medium-gray/30 text-exo-light-gray border border-exo-medium-gray/40\"\n        title={sharding === \"Pipeline\"\n          ? \"Pipeline: splits model into sequential stages across devices. Lower network overhead.\"\n          : \"Tensor: splits each layer across devices. Best with high-bandwidth connections (Thunderbolt).\"}\n      >\n        {sharding}\n      </span>\n      <span\n        class=\"px-1.5 py-0.5 text-xs font-mono tracking-wider uppercase bg-exo-medium-gray/30 text-exo-light-gray border border-exo-medium-gray/40\"\n        title={runtime === \"MlxRing\"\n          ? \"Ring: standard networking. Works over any connection (Wi-Fi, Ethernet, Thunderbolt).\"\n          : \"RDMA: direct memory access over Thunderbolt. 
Significantly faster for multi-device inference.\"}\n      >\n        {runtime === \"MlxRing\"\n          ? \"MLX Ring\"\n          : runtime === \"MlxJaccl\"\n            ? \"MLX RDMA\"\n            : runtime}\n      </span>\n    </div>\n\n    <!-- Download Status (per-node) -->\n    {#if perNode.length > 0}\n      <div class=\"mb-2 space-y-1\">\n        <div\n          class=\"text-[10px] font-mono text-white/20 tracking-widest uppercase\"\n        >\n          Download progress\n        </div>\n        {#each perNode as node}\n          <div class=\"flex items-center gap-2 text-xs font-mono\">\n            <span class=\"text-white/40 w-20 truncate\" title={node.nodeId}\n              >{node.nodeName}</span\n            >\n            <div\n              class=\"flex-1 h-1 bg-exo-medium-gray/30 rounded overflow-hidden\"\n            >\n              <div\n                class=\"h-full transition-all duration-300 {node.status ===\n                'downloading'\n                  ? 'bg-blue-500/70'\n                  : node.status === 'completed'\n                    ? 'bg-exo-yellow/40'\n                    : 'bg-white/20'}\"\n                style=\"width: {node.percentage}%\"\n              ></div>\n            </div>\n            <span\n              class=\"text-right {node.status === 'completed'\n                ? 'text-exo-yellow/60'\n                : node.status === 'downloading'\n                  ? 'text-blue-400/60'\n                  : 'text-white/30'}\"\n            >\n              {#if node.status === \"downloading\" && node.progress}\n                {Math.round(node.percentage)}% {formatSpeed(\n                  node.progress.speed,\n                )}\n              {:else}\n                {node.percentage > 0 ? `${Math.round(node.percentage)}%` : \"0%\"}\n              {/if}\n            </span>\n          </div>\n        {/each}\n      </div>\n    {/if}\n\n    <!-- Mini Topology Preview -->\n    {#if placementPreview().nodes.length > 0}\n      {@const preview = placementPreview()}\n      <div\n        class=\"mb-3 bg-exo-black/60 rounded border border-exo-medium-gray/20 p-2 relative overflow-hidden\"\n      >\n        <!-- Scanline effect -->\n        <div\n          class=\"absolute inset-0 bg-[repeating-linear-gradient(0deg,transparent,transparent_2px,rgba(255,215,0,0.02)_2px,rgba(255,215,0,0.02)_4px)] pointer-events-none\"\n        ></div>\n\n        <svg\n          width=\"100%\"\n          height={preview.topoHeight}\n          viewBox=\"0 0 {preview.topoWidth} {preview.topoHeight}\"\n          class=\"overflow-visible\"\n        >\n          <defs>\n            <!-- Glow filter for active nodes -->\n            <filter\n              id=\"nodeGlow-{filterId}\"\n              x=\"-50%\"\n              y=\"-50%\"\n              width=\"200%\"\n              height=\"200%\"\n            >\n              <feGaussianBlur stdDeviation=\"2\" result=\"blur\" />\n              <feMerge>\n                <feMergeNode in=\"blur\" />\n                <feMergeNode in=\"SourceGraphic\" />\n              </feMerge>\n            </filter>\n\n            <!-- Strong glow for new memory -->\n            <filter\n              id=\"memGlow-{filterId}\"\n              x=\"-100%\"\n              y=\"-100%\"\n              width=\"300%\"\n              height=\"300%\"\n            >\n              <feGaussianBlur stdDeviation=\"3\" result=\"blur\" />\n              <feComposite in=\"SourceGraphic\" in2=\"blur\" operator=\"over\" />\n            </filter>\n          </defs>\n\n       
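   <!-- Debug mode keeps at most one connection per direction for each pair of in-use nodes (non-loopback preferred) and buckets the labels into the four corners by midpoint quadrant -->\n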
          <!-- Connection lines between nodes (if multiple) -->\n          {#if preview.nodes.length > 1}\n            {@const usedNodes = preview.nodes.filter((n) => n.isUsed)}\n            {@const nodePositions = Object.fromEntries(\n              preview.nodes.map((n) => [n.id, { x: n.x, y: n.y }]),\n            )}\n            {@const allConnections =\n              isDebugMode && usedNodes.length > 1\n                ? (() => {\n                    const conns: Array<{\n                      ip: string;\n                      iface: string | null;\n                      from: string;\n                      to: string;\n                      midX: number;\n                      midY: number;\n                      arrow: string;\n                    }> = [];\n                    for (let i = 0; i < usedNodes.length; i++) {\n                      for (let j = i + 1; j < usedNodes.length; j++) {\n                        const n1 = usedNodes[i];\n                        const n2 = usedNodes[j];\n                        const midX = (n1.x + n2.x) / 2;\n                        const midY = (n1.y + n2.y) / 2;\n                        for (const c of getConnectionInfo(n1.id, n2.id)) {\n                          const fromPos = nodePositions[c.from];\n                          const toPos = nodePositions[c.to];\n                          const arrow =\n                            fromPos && toPos ? getArrow(fromPos, toPos) : \"→\";\n                          conns.push({\n                            ...c,\n                            midX,\n                            midY,\n                            arrow,\n                          });\n                        }\n                      }\n                    }\n                    return conns;\n                  })()\n                : []}\n            {#each preview.nodes as node, i}\n              {#each preview.nodes.slice(i + 1) as node2}\n                <line\n                  x1={node.x}\n                  y1={node.y}\n                  x2={node2.x}\n                  y2={node2.y}\n                  stroke={node.isUsed && node2.isUsed ? \"#FFD700\" : \"#374151\"}\n                  stroke-width=\"1\"\n                  stroke-dasharray={node.isUsed && node2.isUsed ? \"4,2\" : \"2,4\"}\n                  opacity={node.isUsed && node2.isUsed ? 0.4 : 0.15}\n                />\n              {/each}\n            {/each}\n            <!-- Debug: Show connection IPs/interfaces in corners -->\n            {#if isDebugMode && allConnections.length > 0}\n              {@const centerX = preview.topoWidth / 2}\n              {@const centerY = preview.topoHeight / 2}\n              {@const quadrants = {\n                topLeft: allConnections.filter(\n                  (c) => c.midX < centerX && c.midY < centerY,\n                ),\n                topRight: allConnections.filter(\n                  (c) => c.midX >= centerX && c.midY < centerY,\n                ),\n                bottomLeft: allConnections.filter(\n                  (c) => c.midX < centerX && c.midY >= centerY,\n                ),\n                bottomRight: allConnections.filter(\n                  (c) => c.midX >= centerX && c.midY >= centerY,\n                ),\n              }}\n              {@const padding = 4}\n              {@const lineHeight = 8}\n              <!-- Top Left -->\n              {#each quadrants.topLeft as conn, idx}\n                <text\n                  x={padding}\n                  y={padding + idx * lineHeight}\n                  text-anchor=\"start\"\n                  dominant-baseline=\"hanging\"\n                  font-size=\"6\"\n                  font-family=\"SF Mono, Monaco, monospace\"\n                  fill={conn.iface\n                    ? 
\"rgba(255,255,255,0.85)\"\n                    : \"rgba(248,113,113,0.85)\"}\n                >\n                  {conn.arrow}\n                  {isRdma\n                    ? conn.iface || \"?\"\n                    : `${conn.ip}${conn.iface ? ` (${conn.iface})` : \"\"}`}\n                </text>\n              {/each}\n              <!-- Top Right -->\n              {#each quadrants.topRight as conn, idx}\n                <text\n                  x={preview.topoWidth - padding}\n                  y={padding + idx * lineHeight}\n                  text-anchor=\"end\"\n                  dominant-baseline=\"hanging\"\n                  font-size=\"6\"\n                  font-family=\"SF Mono, Monaco, monospace\"\n                  fill={conn.iface\n                    ? \"rgba(255,255,255,0.85)\"\n                    : \"rgba(248,113,113,0.85)\"}\n                >\n                  {conn.arrow}\n                  {isRdma\n                    ? conn.iface || \"?\"\n                    : `${conn.ip}${conn.iface ? ` (${conn.iface})` : \"\"}`}\n                </text>\n              {/each}\n              <!-- Bottom Left -->\n              {#each quadrants.bottomLeft as conn, idx}\n                <text\n                  x={padding}\n                  y={preview.topoHeight -\n                    padding -\n                    (quadrants.bottomLeft.length - 1 - idx) * lineHeight}\n                  text-anchor=\"start\"\n                  dominant-baseline=\"auto\"\n                  font-size=\"6\"\n                  font-family=\"SF Mono, Monaco, monospace\"\n                  fill={conn.iface\n                    ? \"rgba(255,255,255,0.85)\"\n                    : \"rgba(248,113,113,0.85)\"}\n                >\n                  {conn.arrow}\n                  {isRdma\n                    ? conn.iface || \"?\"\n                    : `${conn.ip}${conn.iface ? ` (${conn.iface})` : \"\"}`}\n                </text>\n              {/each}\n              <!-- Bottom Right -->\n              {#each quadrants.bottomRight as conn, idx}\n                <text\n                  x={preview.topoWidth - padding}\n                  y={preview.topoHeight -\n                    padding -\n                    (quadrants.bottomRight.length - 1 - idx) * lineHeight}\n                  text-anchor=\"end\"\n                  dominant-baseline=\"auto\"\n                  font-size=\"6\"\n                  font-family=\"SF Mono, Monaco, monospace\"\n                  fill={conn.iface\n                    ? \"rgba(255,255,255,0.85)\"\n                    : \"rgba(248,113,113,0.85)\"}\n                >\n                  {conn.arrow}\n                  {isRdma\n                    ? conn.iface || \"?\"\n                    : `${conn.ip}${conn.iface ? ` (${conn.iface})` : \"\"}`}\n                </text>\n              {/each}\n            {/if}\n          {/if}\n\n          {#each preview.nodes as node}\n            <g\n              transform=\"translate({node.x}, {node.y})\"\n              opacity={node.isUsed ? 1 : 0.25}\n              filter={node.isUsed ? 
`url(#nodeGlow-${filterId})` : \"none\"}\n            >\n              <!-- Device icon based on type -->\n              {#if node.deviceType === \"macbook\"}\n                <!-- MacBook Pro icon with memory fill -->\n                <g\n                  transform=\"translate({-node.iconSize / 2}, {-node.iconSize /\n                    2})\"\n                >\n                  <!-- Screen bezel -->\n                  <rect\n                    x=\"2\"\n                    y=\"0\"\n                    width={node.iconSize - 4}\n                    height={node.iconSize * 0.65}\n                    rx=\"2\"\n                    fill=\"none\"\n                    stroke={node.isUsed ? \"#FFD700\" : \"#4B5563\"}\n                    stroke-width=\"1.5\"\n                  />\n                  <!-- Screen area (memory fill container) -->\n                  <rect\n                    x=\"4\"\n                    y=\"2\"\n                    width={node.iconSize - 8}\n                    height={node.screenHeight}\n                    fill=\"#0a0a0a\"\n                  />\n                  <!-- Current memory fill (gray) -->\n                  <rect\n                    x=\"4\"\n                    y={2 + node.screenHeight - node.currentFillHeight}\n                    width={node.iconSize - 8}\n                    height={node.currentFillHeight}\n                    fill=\"#374151\"\n                  />\n                  <!-- New model memory fill (glowing yellow) -->\n                  {#if node.modelUsageGB > 0 && node.isUsed}\n                    <rect\n                      x=\"4\"\n                      y={2 +\n                        node.screenHeight -\n                        node.currentFillHeight -\n                        node.modelFillHeight}\n                      width={node.iconSize - 8}\n                      height={node.modelFillHeight}\n                      fill=\"#FFD700\"\n                      filter=\"url(#memGlow-{filterId})\"\n                      class=\"animate-pulse-slow\"\n                    />\n                  {/if}\n                  <!-- Base/keyboard -->\n                  <path\n                    d=\"M 0 {node.iconSize *\n                      0.68} L {node.iconSize} {node.iconSize *\n                      0.68} L {node.iconSize - 2} {node.iconSize *\n                      0.78} L 2 {node.iconSize * 0.78} Z\"\n                    fill=\"none\"\n                    stroke={node.isUsed ? \"#FFD700\" : \"#4B5563\"}\n                    stroke-width=\"1.5\"\n                  />\n                </g>\n              {:else if node.deviceType === \"studio\"}\n                <!-- Mac Studio icon -->\n                <g\n                  transform=\"translate({-node.iconSize / 2}, {-node.iconSize /\n                    2})\"\n                >\n                  <rect\n                    x=\"2\"\n                    y=\"2\"\n                    width={node.iconSize - 4}\n                    height={node.iconSize - 4}\n                    rx=\"4\"\n                    fill=\"none\"\n                    stroke={node.isUsed ? 
\"#FFD700\" : \"#4B5563\"}\n                    stroke-width=\"1.5\"\n                  />\n                  <!-- Memory fill background -->\n                  <rect\n                    x=\"4\"\n                    y=\"4\"\n                    width={node.iconSize - 8}\n                    height={node.iconSize - 8}\n                    fill=\"#0a0a0a\"\n                  />\n                  <!-- Current memory fill -->\n                  <rect\n                    x=\"4\"\n                    y={4 +\n                      (node.iconSize - 8) * (1 - node.currentPercent / 100)}\n                    width={node.iconSize - 8}\n                    height={(node.iconSize - 8) * (node.currentPercent / 100)}\n                    fill=\"#374151\"\n                  />\n                  <!-- New model memory fill -->\n                  {#if node.modelUsageGB > 0 && node.isUsed}\n                    <rect\n                      x=\"4\"\n                      y={4 + (node.iconSize - 8) * (1 - node.newPercent / 100)}\n                      width={node.iconSize - 8}\n                      height={(node.iconSize - 8) *\n                        ((node.newPercent - node.currentPercent) / 100)}\n                      fill=\"#FFD700\"\n                      filter=\"url(#memGlow-{filterId})\"\n                      class=\"animate-pulse-slow\"\n                    />\n                  {/if}\n                </g>\n              {:else if node.deviceType === \"mini\"}\n                <!-- Mac Mini icon -->\n                <g\n                  transform=\"translate({-node.iconSize / 2}, {-node.iconSize /\n                    2})\"\n                >\n                  <rect\n                    x=\"2\"\n                    y={node.iconSize * 0.3}\n                    width={node.iconSize - 4}\n                    height={node.iconSize * 0.4}\n                    rx=\"3\"\n                    fill=\"none\"\n                    stroke={node.isUsed ? 
\"#FFD700\" : \"#4B5563\"}\n                    stroke-width=\"1.5\"\n                  />\n                  <!-- Memory fill background -->\n                  <rect\n                    x=\"4\"\n                    y={node.iconSize * 0.32}\n                    width={node.iconSize - 8}\n                    height={node.iconSize * 0.36}\n                    fill=\"#0a0a0a\"\n                  />\n                  <!-- Current memory fill -->\n                  <rect\n                    x=\"4\"\n                    y={node.iconSize * 0.32 +\n                      node.iconSize * 0.36 * (1 - node.currentPercent / 100)}\n                    width={node.iconSize - 8}\n                    height={node.iconSize * 0.36 * (node.currentPercent / 100)}\n                    fill=\"#374151\"\n                  />\n                  <!-- New model memory fill -->\n                  {#if node.modelUsageGB > 0 && node.isUsed}\n                    <rect\n                      x=\"4\"\n                      y={node.iconSize * 0.32 +\n                        node.iconSize * 0.36 * (1 - node.newPercent / 100)}\n                      width={node.iconSize - 8}\n                      height={node.iconSize *\n                        0.36 *\n                        ((node.newPercent - node.currentPercent) / 100)}\n                      fill=\"#FFD700\"\n                      filter=\"url(#memGlow-{filterId})\"\n                      class=\"animate-pulse-slow\"\n                    />\n                  {/if}\n                </g>\n              {:else}\n                <!-- Unknown device - hexagon -->\n                <g\n                  transform=\"translate({-node.iconSize / 2}, {-node.iconSize /\n                    2})\"\n                >\n                  <polygon\n                    points=\"{node.iconSize /\n                      2},0 {node.iconSize},{node.iconSize *\n                      0.25} {node.iconSize},{node.iconSize *\n                      0.75} {node.iconSize /\n                      2},{node.iconSize} 0,{node.iconSize *\n                      0.75} 0,{node.iconSize * 0.25}\"\n                    fill={node.isUsed ? \"rgba(255,215,0,0.1)\" : \"#0a0a0a\"}\n                    stroke={node.isUsed ? \"#FFD700\" : \"#4B5563\"}\n                    stroke-width=\"1.5\"\n                  />\n                </g>\n              {/if}\n\n              <!-- Percentage label -->\n              <text\n                y={node.iconSize / 2 + 12}\n                text-anchor=\"middle\"\n                font-size=\"8\"\n                font-family=\"SF Mono, Monaco, monospace\"\n                fill={node.isUsed\n                  ? node.newPercent > 90\n                    ? \"#f87171\"\n                    : \"#FFD700\"\n                  : \"#4B5563\"}\n              >\n                {node.newPercent.toFixed(0)}%\n              </text>\n            </g>\n          {/each}\n        </svg>\n      </div>\n    {/if}\n\n    <!-- Launch Button -->\n    <button\n      onclick={onLaunch}\n      disabled={isLaunching || !canFit}\n      class=\"w-full py-2 text-sm font-mono tracking-wider uppercase border transition-all duration-200\n\t\t\t\t{isLaunching\n        ? 'bg-transparent text-exo-yellow border-exo-yellow/50 cursor-wait'\n        : !canFit\n          ? 
'bg-red-500/10 text-red-400/70 border-red-500/30 cursor-not-allowed'\n          : 'bg-transparent text-exo-light-gray border-exo-light-gray/40 hover:text-exo-yellow hover:border-exo-yellow/50 cursor-pointer'}\"\n    >\n      {#if isLaunching}\n        <span class=\"flex items-center justify-center gap-1.5\">\n          <span\n            class=\"w-2 h-2 border border-exo-yellow border-t-transparent rounded-full animate-spin\"\n          ></span>\n          LAUNCHING...\n        </span>\n      {:else if !canFit}\n        INSUFFICIENT MEMORY\n      {:else}\n        ▸ LAUNCH\n      {/if}\n    </button>\n  </div>\n</div>\n\n<style>\n  @keyframes pulse-slow {\n    0%,\n    100% {\n      opacity: 0.8;\n    }\n    50% {\n      opacity: 1;\n    }\n  }\n  .animate-pulse-slow {\n    animation: pulse-slow 1.5s ease-in-out infinite;\n  }\n</style>\n"
  },
  {
    "path": "dashboard/src/lib/components/ModelFilterPopover.svelte",
    "content": "<script lang=\"ts\">\n  import { fly } from \"svelte/transition\";\n  import { cubicOut } from \"svelte/easing\";\n\n  interface FilterState {\n    capabilities: string[];\n    sizeRange: { min: number; max: number } | null;\n    downloadedOnly: boolean;\n    readyOnly: boolean;\n  }\n\n  type ModelFilterPopoverProps = {\n    filters: FilterState;\n    onChange: (filters: FilterState) => void;\n    onClear: () => void;\n    onClose: () => void;\n  };\n\n  let { filters, onChange, onClear, onClose }: ModelFilterPopoverProps =\n    $props();\n\n  // Available capabilities\n  const availableCapabilities = [\n    { id: \"text\", label: \"Text\" },\n    { id: \"thinking\", label: \"Thinking\" },\n    { id: \"code\", label: \"Code\" },\n    { id: \"vision\", label: \"Vision\" },\n    { id: \"image_gen\", label: \"Image Gen\" },\n    { id: \"image_edit\", label: \"Image Edit\" },\n  ];\n\n  // Size ranges\n  const sizeRanges = [\n    { label: \"< 10GB\", min: 0, max: 10 },\n    { label: \"10-50GB\", min: 10, max: 50 },\n    { label: \"50-200GB\", min: 50, max: 200 },\n    { label: \"> 200GB\", min: 200, max: 10000 },\n  ];\n\n  function toggleCapability(cap: string) {\n    const next = filters.capabilities.includes(cap)\n      ? filters.capabilities.filter((c) => c !== cap)\n      : [...filters.capabilities, cap];\n    onChange({ ...filters, capabilities: next });\n  }\n\n  function selectSizeRange(range: { min: number; max: number } | null) {\n    // Toggle off if same range is clicked\n    if (\n      filters.sizeRange &&\n      range &&\n      filters.sizeRange.min === range.min &&\n      filters.sizeRange.max === range.max\n    ) {\n      onChange({ ...filters, sizeRange: null });\n    } else {\n      onChange({ ...filters, sizeRange: range });\n    }\n  }\n\n  function handleClickOutside(e: MouseEvent) {\n    const target = e.target as HTMLElement;\n    if (\n      !target.closest(\".filter-popover\") &&\n      !target.closest(\".filter-toggle\")\n    ) {\n      onClose();\n    }\n  }\n</script>\n\n<svelte:window onclick={handleClickOutside} />\n\n<!-- svelte-ignore a11y_no_static_element_interactions -->\n<div\n  class=\"filter-popover absolute right-0 top-full mt-2 w-64 bg-exo-dark-gray border border-exo-yellow/10 rounded-lg shadow-xl z-10\"\n  transition:fly={{ y: -10, duration: 200, easing: cubicOut }}\n  onclick={(e) => e.stopPropagation()}\n  role=\"dialog\"\n  aria-label=\"Filter options\"\n>\n  <div class=\"p-3 space-y-4\">\n    <!-- Capabilities -->\n    <div>\n      <h4 class=\"text-xs font-mono text-white/50 mb-2\">Capabilities</h4>\n      <div class=\"flex flex-wrap gap-1.5\">\n        {#each availableCapabilities as cap}\n          {@const isSelected = filters.capabilities.includes(cap.id)}\n          <button\n            type=\"button\"\n            class=\"px-2 py-1 text-xs font-mono rounded transition-colors {isSelected\n              ? 
'bg-exo-yellow/20 text-exo-yellow border border-exo-yellow/30'\n              : 'bg-white/5 text-white/60 hover:bg-white/10 border border-transparent'}\"\n            onclick={() => toggleCapability(cap.id)}\n          >\n            {#if cap.id === \"text\"}\n              <svg\n                class=\"w-3.5 h-3.5 inline-block\"\n                viewBox=\"0 0 24 24\"\n                fill=\"none\"\n                stroke=\"currentColor\"\n                stroke-width=\"1.5\"\n                ><path\n                  d=\"M21 15a2 2 0 0 1-2 2H7l-4 4V5a2 2 0 0 1 2-2h14a2 2 0 0 1 2 2z\"\n                  stroke-linecap=\"round\"\n                  stroke-linejoin=\"round\"\n                /></svg\n              >\n            {:else if cap.id === \"thinking\"}\n              <svg\n                class=\"w-3.5 h-3.5 inline-block\"\n                viewBox=\"0 0 24 24\"\n                fill=\"none\"\n                stroke=\"currentColor\"\n                stroke-width=\"1.5\"\n                ><path\n                  d=\"M12 2a7 7 0 0 0-7 7c0 2.38 1.19 4.47 3 5.74V17a1 1 0 0 0 1 1h6a1 1 0 0 0 1-1v-2.26c1.81-1.27 3-3.36 3-5.74a7 7 0 0 0-7-7zM9 20h6M10 22h4\"\n                  stroke-linecap=\"round\"\n                  stroke-linejoin=\"round\"\n                /></svg\n              >\n            {:else if cap.id === \"code\"}\n              <svg\n                class=\"w-3.5 h-3.5 inline-block\"\n                viewBox=\"0 0 24 24\"\n                fill=\"none\"\n                stroke=\"currentColor\"\n                stroke-width=\"1.5\"\n                ><path\n                  d=\"M16 18l6-6-6-6M8 6l-6 6 6 6\"\n                  stroke-linecap=\"round\"\n                  stroke-linejoin=\"round\"\n                /></svg\n              >\n            {:else if cap.id === \"vision\"}\n              <svg\n                class=\"w-3.5 h-3.5 inline-block\"\n                viewBox=\"0 0 24 24\"\n                fill=\"none\"\n                stroke=\"currentColor\"\n                stroke-width=\"1.5\"\n                ><path\n                  d=\"M1 12s4-8 11-8 11 8 11 8-4 8-11 8-11-8-11-8z\"\n                  stroke-linecap=\"round\"\n                  stroke-linejoin=\"round\"\n                /><circle cx=\"12\" cy=\"12\" r=\"3\" /></svg\n              >\n            {:else if cap.id === \"image_gen\"}\n              <svg\n                class=\"w-3.5 h-3.5 inline-block\"\n                viewBox=\"0 0 24 24\"\n                fill=\"none\"\n                stroke=\"currentColor\"\n                stroke-width=\"1.5\"\n                ><rect\n                  x=\"3\"\n                  y=\"3\"\n                  width=\"18\"\n                  height=\"18\"\n                  rx=\"2\"\n                  ry=\"2\"\n                  stroke-linecap=\"round\"\n                  stroke-linejoin=\"round\"\n                /><circle cx=\"8.5\" cy=\"8.5\" r=\"1.5\" /><path\n                  d=\"M21 15l-5-5L5 21\"\n                  stroke-linecap=\"round\"\n                  stroke-linejoin=\"round\"\n                /></svg\n              >\n            {:else if cap.id === \"image_edit\"}\n              <svg\n                class=\"w-3.5 h-3.5 inline-block\"\n                viewBox=\"0 0 24 24\"\n                fill=\"none\"\n                stroke=\"currentColor\"\n                stroke-width=\"1.5\"\n                ><path\n                  d=\"M11 4H4a2 2 0 0 0-2 2v14a2 2 0 0 0 2 2h14a2 2 0 0 0 2-2v-7\"\n                  stroke-linecap=\"round\"\n                  
stroke-linejoin=\"round\"\n                /><path\n                  d=\"M18.5 2.5a2.121 2.121 0 0 1 3 3L12 15l-4 1 1-4 9.5-9.5z\"\n                  stroke-linecap=\"round\"\n                  stroke-linejoin=\"round\"\n                /></svg\n              >\n            {/if}\n            <span class=\"ml-1\">{cap.label}</span>\n          </button>\n        {/each}\n      </div>\n    </div>\n\n    <!-- Availability filters -->\n    <div>\n      <h4 class=\"text-xs font-mono text-white/50 mb-2\">Availability</h4>\n      <div class=\"flex flex-wrap gap-1.5\">\n        <button\n          type=\"button\"\n          class=\"px-2 py-1 text-xs font-mono rounded transition-colors {filters.downloadedOnly\n            ? 'bg-green-500/20 text-green-400 border border-green-500/30'\n            : 'bg-white/5 text-white/60 hover:bg-white/10 border border-transparent'}\"\n          onclick={() =>\n            onChange({ ...filters, downloadedOnly: !filters.downloadedOnly })}\n        >\n          <svg\n            class=\"w-3.5 h-3.5 inline-block\"\n            viewBox=\"0 0 24 24\"\n            fill=\"none\"\n            stroke=\"currentColor\"\n            stroke-width=\"2\"\n            stroke-linecap=\"round\"\n            stroke-linejoin=\"round\"\n          >\n            <path\n              class=\"text-white/40\"\n              d=\"M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z\"\n            />\n            <path class=\"text-green-400\" d=\"m9 13 2 2 4-4\" />\n          </svg>\n          <span class=\"ml-1\">Downloaded</span>\n        </button>\n        <button\n          type=\"button\"\n          class=\"px-2 py-1 text-xs font-mono rounded transition-colors {filters.readyOnly\n            ? 'bg-green-500/20 text-green-400 border border-green-500/30'\n            : 'bg-white/5 text-white/60 hover:bg-white/10 border border-transparent'}\"\n          onclick={() =>\n            onChange({ ...filters, readyOnly: !filters.readyOnly })}\n        >\n          <svg\n            class=\"w-3.5 h-3.5 inline-block\"\n            viewBox=\"0 0 24 24\"\n            fill=\"none\"\n            stroke=\"currentColor\"\n            stroke-width=\"2\"\n            stroke-linecap=\"round\"\n            stroke-linejoin=\"round\"\n          >\n            <circle cx=\"12\" cy=\"12\" r=\"10\" />\n            <path d=\"m9 12 2 2 4-4\" />\n          </svg>\n          <span class=\"ml-1\">Ready</span>\n        </button>\n      </div>\n    </div>\n\n    <!-- Size range -->\n    <div>\n      <h4 class=\"text-xs font-mono text-white/50 mb-2\">Model Size</h4>\n      <div class=\"flex flex-wrap gap-1.5\">\n        {#each sizeRanges as range}\n          {@const isSelected =\n            filters.sizeRange &&\n            filters.sizeRange.min === range.min &&\n            filters.sizeRange.max === range.max}\n          <button\n            type=\"button\"\n            class=\"px-2 py-1 text-xs font-mono rounded transition-colors {isSelected\n              ? 
'bg-exo-yellow/20 text-exo-yellow border border-exo-yellow/30'\n              : 'bg-white/5 text-white/60 hover:bg-white/10 border border-transparent'}\"\n            onclick={() => selectSizeRange(range)}\n          >\n            {range.label}\n          </button>\n        {/each}\n      </div>\n    </div>\n\n    <!-- Clear button -->\n    <button\n      type=\"button\"\n      class=\"w-full py-1.5 text-xs font-mono text-white/50 hover:text-white/70 hover:bg-white/5 rounded transition-colors\"\n      onclick={onClear}\n    >\n      Clear all filters\n    </button>\n  </div>\n</div>\n"
  },
  {
    "path": "dashboard/src/lib/components/ModelPickerGroup.svelte",
    "content": "<script lang=\"ts\">\n  interface ModelInfo {\n    id: string;\n    name?: string;\n    storage_size_megabytes?: number;\n    base_model?: string;\n    quantization?: string;\n    supports_tensor?: boolean;\n    capabilities?: string[];\n    family?: string;\n    is_custom?: boolean;\n  }\n\n  interface ModelGroup {\n    id: string;\n    name: string;\n    capabilities: string[];\n    family: string;\n    variants: ModelInfo[];\n    smallestVariant: ModelInfo;\n    hasMultipleVariants: boolean;\n  }\n\n  type DownloadAvailability = {\n    available: boolean;\n    nodeNames: string[];\n    nodeIds: string[];\n  };\n  type ModelFitStatus = \"fits_now\" | \"fits_cluster_capacity\" | \"too_large\";\n\n  type ModelPickerGroupProps = {\n    group: ModelGroup;\n    isExpanded: boolean;\n    isFavorite: boolean;\n    isHighlighted?: boolean;\n    selectedModelId: string | null;\n    canModelFit: (id: string) => boolean;\n    getModelFitStatus: (id: string) => ModelFitStatus;\n    onToggleExpand: () => void;\n    onSelectModel: (modelId: string) => void;\n    onToggleFavorite: (baseModelId: string) => void;\n    onShowInfo: (group: ModelGroup) => void;\n    downloadStatusMap?: Map<string, DownloadAvailability>;\n    launchedAt?: number;\n    instanceStatuses?: Record<string, { status: string; statusClass: string }>;\n  };\n\n  let {\n    group,\n    isExpanded,\n    isFavorite,\n    isHighlighted = false,\n    selectedModelId,\n    canModelFit,\n    getModelFitStatus,\n    onToggleExpand,\n    onSelectModel,\n    onToggleFavorite,\n    onShowInfo,\n    downloadStatusMap,\n    launchedAt,\n    instanceStatuses = {},\n  }: ModelPickerGroupProps = $props();\n\n  // Group-level download status: show if any variant is downloaded\n  const groupDownloadStatus = $derived.by(() => {\n    if (!downloadStatusMap || downloadStatusMap.size === 0) return undefined;\n    // Return the first available entry (prefer \"available\" ones)\n    for (const avail of downloadStatusMap.values()) {\n      if (avail.available) return avail;\n    }\n    return downloadStatusMap.values().next().value;\n  });\n\n  // Format storage size\n  function formatSize(mb: number | undefined): string {\n    if (!mb) return \"\";\n    if (mb >= 1024) {\n      return `${(mb / 1024).toFixed(0)}GB`;\n    }\n    return `${mb}MB`;\n  }\n\n  function timeAgo(ts: number): string {\n    const seconds = Math.floor((Date.now() - ts) / 1000);\n    if (seconds < 60) return \"just now\";\n    const minutes = Math.floor(seconds / 60);\n    if (minutes < 60) return `${minutes}m ago`;\n    const hours = Math.floor(minutes / 60);\n    if (hours < 24) return `${hours}h ago`;\n    const days = Math.floor(hours / 24);\n    return `${days}d ago`;\n  }\n\n  // Check if any variant can fit\n  const anyVariantFits = $derived(\n    group.variants.some((v) => canModelFit(v.id)),\n  );\n  // Check if any variant has an active instance (ready, loading, downloading)\n  const anyVariantHasInstance = $derived(\n    instanceStatuses\n      ? group.variants.some((v) => instanceStatuses[v.id] != null)\n      : false,\n  );\n  const groupFitStatus = $derived.by((): ModelFitStatus => {\n    let hasClusterCapacityOnly = false;\n    for (const variant of group.variants) {\n      const fitStatus = getModelFitStatus(variant.id);\n      if (fitStatus === \"fits_now\") {\n        return \"fits_now\";\n      }\n      if (fitStatus === \"fits_cluster_capacity\") {\n        hasClusterCapacityOnly = true;\n      }\n    }\n    return hasClusterCapacityOnly ? 
\"fits_cluster_capacity\" : \"too_large\";\n  });\n\n  function getSizeClassForFitStatus(fitStatus: ModelFitStatus): string {\n    switch (fitStatus) {\n      case \"fits_now\":\n        return \"text-white/40\";\n      case \"fits_cluster_capacity\":\n        return \"text-orange-400/80\";\n      case \"too_large\":\n        return \"text-red-400/70\";\n    }\n  }\n\n  // Check if this group's model is currently selected (for single-variant groups)\n  const isMainSelected = $derived(\n    !group.hasMultipleVariants &&\n      group.variants.some((v) => v.id === selectedModelId),\n  );\n\n  // Group-level instance status: show the \"best\" status across all variants\n  const groupInstanceStatus = $derived.by(() => {\n    if (!instanceStatuses) return null;\n    const readyStatuses = [\"READY\", \"LOADED\", \"RUNNING\"];\n    const loadingStatuses = [\"LOADING\", \"WARMING UP\"];\n    let bestStatus: { status: string; statusClass: string } | null = null;\n    for (const variant of group.variants) {\n      const s = instanceStatuses[variant.id];\n      if (!s) continue;\n      if (readyStatuses.includes(s.status)) return s; // Ready is best\n      if (loadingStatuses.includes(s.status) || s.status === \"DOWNLOADING\") {\n        bestStatus = s;\n      }\n    }\n    return bestStatus;\n  });\n</script>\n\n<div\n  data-model-ids={group.variants.map((v) => v.id).join(\" \")}\n  class=\"border-b border-white/5 last:border-b-0 {!anyVariantFits &&\n  !anyVariantHasInstance\n    ? 'opacity-50'\n    : ''} {isHighlighted ? 'model-just-added' : ''}\"\n>\n  <!-- Main row -->\n  <div\n    class=\"flex items-center gap-2 px-3 py-2.5 transition-colors {anyVariantFits ||\n    anyVariantHasInstance\n      ? 'hover:bg-white/5 cursor-pointer'\n      : 'cursor-not-allowed'} {isMainSelected\n      ? 'bg-exo-yellow/10 border-l-2 border-exo-yellow'\n      : 'border-l-2 border-transparent'}\"\n    onclick={() => {\n      if (group.hasMultipleVariants) {\n        onToggleExpand();\n      } else {\n        const modelId = group.variants[0]?.id;\n        if (modelId && (canModelFit(modelId) || instanceStatuses[modelId])) {\n          onSelectModel(modelId);\n        }\n      }\n    }}\n    role=\"button\"\n    tabindex=\"0\"\n    onkeydown={(e) => {\n      if (e.key === \"Enter\" || e.key === \" \") {\n        e.preventDefault();\n        if (group.hasMultipleVariants) {\n          onToggleExpand();\n        } else {\n          const modelId = group.variants[0]?.id;\n          if (modelId && (canModelFit(modelId) || instanceStatuses[modelId])) {\n            onSelectModel(modelId);\n          }\n        }\n      }\n    }}\n  >\n    <!-- Expand/collapse chevron (for groups with variants) -->\n    {#if group.hasMultipleVariants}\n      <svg\n        class=\"w-4 h-4 text-white/40 transition-transform duration-200 flex-shrink-0 {isExpanded\n          ? 
'rotate-90'\n          : ''}\"\n        viewBox=\"0 0 24 24\"\n        fill=\"currentColor\"\n      >\n        <path d=\"M8.59 16.59L13.17 12 8.59 7.41 10 6l6 6-6 6-1.41-1.41z\" />\n      </svg>\n    {:else}\n      <div class=\"w-4 flex-shrink-0\"></div>\n    {/if}\n\n    <!-- Model name -->\n    <div class=\"flex-1 min-w-0\">\n      <div class=\"flex items-center gap-2\">\n        <span class=\"font-mono text-sm text-white truncate\">\n          {group.name}\n        </span>\n        <!-- Capability icons -->\n        {#each group.capabilities.filter((c) => c !== \"text\") as cap}\n          {#if cap === \"thinking\"}\n            <svg\n              class=\"w-3.5 h-3.5 text-white/40 flex-shrink-0\"\n              viewBox=\"0 0 24 24\"\n              fill=\"none\"\n              stroke=\"currentColor\"\n              stroke-width=\"1.5\"\n              title=\"Supports Thinking\"\n            >\n              <path\n                d=\"M12 2a7 7 0 0 0-7 7c0 2.38 1.19 4.47 3 5.74V17a1 1 0 0 0 1 1h6a1 1 0 0 0 1-1v-2.26c1.81-1.27 3-3.36 3-5.74a7 7 0 0 0-7-7zM9 20h6M10 22h4\"\n                stroke-linecap=\"round\"\n                stroke-linejoin=\"round\"\n              />\n            </svg>\n          {:else if cap === \"code\"}\n            <svg\n              class=\"w-3.5 h-3.5 text-white/40 flex-shrink-0\"\n              viewBox=\"0 0 24 24\"\n              fill=\"none\"\n              stroke=\"currentColor\"\n              stroke-width=\"1.5\"\n              title=\"Supports code generation\"\n            >\n              <path\n                d=\"M16 18l6-6-6-6M8 6l-6 6 6 6\"\n                stroke-linecap=\"round\"\n                stroke-linejoin=\"round\"\n              />\n            </svg>\n          {:else if cap === \"vision\"}\n            <svg\n              class=\"w-3.5 h-3.5 text-white/40 flex-shrink-0\"\n              viewBox=\"0 0 24 24\"\n              fill=\"none\"\n              stroke=\"currentColor\"\n              stroke-width=\"1.5\"\n              title=\"Supports image input\"\n            >\n              <path\n                d=\"M1 12s4-8 11-8 11 8 11 8-4 8-11 8-11-8-11-8z\"\n                stroke-linecap=\"round\"\n                stroke-linejoin=\"round\"\n              />\n              <circle cx=\"12\" cy=\"12\" r=\"3\" />\n            </svg>\n          {:else if cap === \"image_gen\"}\n            <svg\n              class=\"w-3.5 h-3.5 text-white/40 flex-shrink-0\"\n              viewBox=\"0 0 24 24\"\n              fill=\"none\"\n              stroke=\"currentColor\"\n              stroke-width=\"1.5\"\n              title=\"Supports image generation\"\n            >\n              <rect\n                x=\"3\"\n                y=\"3\"\n                width=\"18\"\n                height=\"18\"\n                rx=\"2\"\n                ry=\"2\"\n                stroke-linecap=\"round\"\n                stroke-linejoin=\"round\"\n              />\n              <circle cx=\"8.5\" cy=\"8.5\" r=\"1.5\" />\n              <path\n                d=\"M21 15l-5-5L5 21\"\n                stroke-linecap=\"round\"\n                stroke-linejoin=\"round\"\n              />\n            </svg>\n          {:else if cap === \"image_edit\"}\n            <svg\n              class=\"w-3.5 h-3.5 text-white/40 flex-shrink-0\"\n              viewBox=\"0 0 24 24\"\n              fill=\"none\"\n              stroke=\"currentColor\"\n              stroke-width=\"1.5\"\n              title=\"Supports image editing\"\n            >\n              <path\n          
      d=\"M11 4H4a2 2 0 0 0-2 2v14a2 2 0 0 0 2 2h14a2 2 0 0 0 2-2v-7\"\n                stroke-linecap=\"round\"\n                stroke-linejoin=\"round\"\n              />\n              <path\n                d=\"M18.5 2.5a2.121 2.121 0 0 1 3 3L12 15l-4 1 1-4 9.5-9.5z\"\n                stroke-linecap=\"round\"\n                stroke-linejoin=\"round\"\n              />\n            </svg>\n          {/if}\n        {/each}\n      </div>\n    </div>\n\n    <!-- Size indicator (smallest variant) -->\n    {#if !group.hasMultipleVariants && group.smallestVariant?.storage_size_megabytes}\n      {@const singleVariantFitStatus = getModelFitStatus(\n        group.smallestVariant.id,\n      )}\n      <span\n        class=\"text-xs font-mono flex-shrink-0 {getSizeClassForFitStatus(\n          singleVariantFitStatus,\n        )}\"\n      >\n        {formatSize(group.smallestVariant.storage_size_megabytes)}\n      </span>\n    {/if}\n\n    <!-- Variant count with size range -->\n    {#if group.hasMultipleVariants}\n      {@const sizes = group.variants\n        .map((v) => v.storage_size_megabytes || 0)\n        .filter((s) => s > 0)\n        .sort((a, b) => a - b)}\n      <span\n        class=\"text-xs font-mono flex-shrink-0 {getSizeClassForFitStatus(\n          groupFitStatus,\n        )}\"\n      >\n        {group.variants.length} variants{#if sizes.length >= 2}{\" \"}({formatSize(\n            sizes[0],\n          )}-{formatSize(sizes[sizes.length - 1])}){/if}\n      </span>\n    {/if}\n\n    <!-- Time ago (for recent models) -->\n    {#if launchedAt}\n      <span class=\"text-xs font-mono text-white/20 flex-shrink-0\">\n        {timeAgo(launchedAt)}\n      </span>\n    {/if}\n\n    <!-- Download availability indicator -->\n    {#if groupDownloadStatus && groupDownloadStatus.nodeIds.length > 0}\n      <span\n        class=\"flex-shrink-0\"\n        title={groupDownloadStatus.available\n          ? 
`Ready — downloaded on ${groupDownloadStatus.nodeNames.join(\", \")}`\n          : `Downloaded on ${groupDownloadStatus.nodeNames.join(\", \")} (may need more nodes)`}\n      >\n        <svg\n          class=\"w-4 h-4\"\n          viewBox=\"0 0 24 24\"\n          fill=\"none\"\n          stroke=\"currentColor\"\n          stroke-width=\"2\"\n          stroke-linecap=\"round\"\n          stroke-linejoin=\"round\"\n        >\n          <path\n            class=\"text-white/40\"\n            d=\"M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z\"\n          />\n          <path class=\"text-green-400\" d=\"m9 13 2 2 4-4\" />\n        </svg>\n      </span>\n    {/if}\n\n    <!-- Instance status badge -->\n    {#if groupInstanceStatus}\n      {#if groupInstanceStatus.status === \"READY\" || groupInstanceStatus.status === \"LOADED\" || groupInstanceStatus.status === \"RUNNING\"}\n        <span class=\"flex-shrink-0\" title=\"Running\">\n          <svg\n            class=\"w-3 h-3 text-green-400\"\n            viewBox=\"0 0 12 12\"\n            fill=\"currentColor\"\n          >\n            <circle cx=\"6\" cy=\"6\" r=\"5\" />\n          </svg>\n        </span>\n      {:else if groupInstanceStatus.status === \"DOWNLOADING\"}\n        <span class=\"flex-shrink-0 animate-pulse\" title=\"Downloading\">\n          <svg\n            class=\"w-3.5 h-3.5 text-blue-400\"\n            viewBox=\"0 0 24 24\"\n            fill=\"none\"\n            stroke=\"currentColor\"\n            stroke-width=\"2\"\n            stroke-linecap=\"round\"\n            stroke-linejoin=\"round\"\n          >\n            <path d=\"M21 15v4a2 2 0 01-2 2H5a2 2 0 01-2-2v-4\" />\n            <polyline points=\"7 10 12 15 17 10\" />\n            <line x1=\"12\" y1=\"15\" x2=\"12\" y2=\"3\" />\n          </svg>\n        </span>\n      {:else if groupInstanceStatus.status === \"LOADING\" || groupInstanceStatus.status === \"WARMING UP\"}\n        <span class=\"flex-shrink-0 animate-pulse\" title=\"Loading\">\n          <svg\n            class=\"w-3 h-3 text-yellow-400\"\n            viewBox=\"0 0 12 12\"\n            fill=\"currentColor\"\n          >\n            <circle cx=\"6\" cy=\"6\" r=\"5\" />\n          </svg>\n        </span>\n      {/if}\n    {/if}\n\n    <!-- Check mark if selected (single-variant) -->\n    {#if isMainSelected}\n      <svg\n        class=\"w-4 h-4 text-exo-yellow flex-shrink-0\"\n        viewBox=\"0 0 24 24\"\n        fill=\"currentColor\"\n      >\n        <path d=\"M9 16.17L4.83 12l-1.42 1.41L9 19 21 7l-1.41-1.41L9 16.17z\" />\n      </svg>\n    {/if}\n\n    <!-- Favorite star -->\n    <button\n      type=\"button\"\n      class=\"p-1 rounded hover:bg-white/10 transition-colors flex-shrink-0\"\n      onclick={(e) => {\n        e.stopPropagation();\n        onToggleFavorite(group.id);\n      }}\n      title={isFavorite ? 
\"Remove from favorites\" : \"Add to favorites\"}\n    >\n      {#if isFavorite}\n        <svg\n          class=\"w-4 h-4 text-amber-400\"\n          viewBox=\"0 0 24 24\"\n          fill=\"currentColor\"\n        >\n          <path\n            d=\"M12 2l3.09 6.26L22 9.27l-5 4.87 1.18 6.88L12 17.77l-6.18 3.25L7 14.14 2 9.27l6.91-1.01L12 2z\"\n          />\n        </svg>\n      {:else}\n        <svg\n          class=\"w-4 h-4 text-white/30 hover:text-white/50\"\n          viewBox=\"0 0 24 24\"\n          fill=\"none\"\n          stroke=\"currentColor\"\n          stroke-width=\"2\"\n        >\n          <path\n            d=\"M12 2l3.09 6.26L22 9.27l-5 4.87 1.18 6.88L12 17.77l-6.18 3.25L7 14.14 2 9.27l6.91-1.01L12 2z\"\n          />\n        </svg>\n      {/if}\n    </button>\n\n    <!-- Info button -->\n    <button\n      type=\"button\"\n      class=\"p-1 rounded hover:bg-white/10 transition-colors flex-shrink-0\"\n      onclick={(e) => {\n        e.stopPropagation();\n        onShowInfo(group);\n      }}\n      title=\"View model details\"\n    >\n      <svg\n        class=\"w-4 h-4 text-white/30 hover:text-white/50\"\n        viewBox=\"0 0 24 24\"\n        fill=\"currentColor\"\n      >\n        <path\n          d=\"M12 2C6.48 2 2 6.48 2 12s4.48 10 10 10 10-4.48 10-10S17.52 2 12 2zm1 15h-2v-6h2v6zm0-8h-2V7h2v2z\"\n        />\n      </svg>\n    </button>\n  </div>\n\n  <!-- Expanded variants -->\n  {#if isExpanded && group.hasMultipleVariants}\n    <div class=\"bg-black/20 border-t border-white/5\">\n      {#each group.variants as variant}\n        {@const fitStatus = getModelFitStatus(variant.id)}\n        {@const modelCanFit = canModelFit(variant.id)}\n        {@const variantHasInstance = instanceStatuses[variant.id] != null}\n        {@const isSelected = selectedModelId === variant.id}\n        <div\n          class=\"w-full flex items-center gap-3 px-3 py-2 pl-10 hover:bg-white/5 transition-colors text-left {!modelCanFit &&\n          !variantHasInstance\n            ? 'opacity-50 cursor-not-allowed'\n            : 'cursor-pointer'} {isSelected\n            ? 
'bg-exo-yellow/10 border-l-2 border-exo-yellow'\n            : 'border-l-2 border-transparent'}\"\n          role=\"button\"\n          tabindex=\"0\"\n          onclick={() => {\n            if (modelCanFit || variantHasInstance) {\n              onSelectModel(variant.id);\n            }\n          }}\n          onkeydown={(e) => {\n            if (e.key === \"Enter\" || e.key === \" \") {\n              e.preventDefault();\n              // Keyboard activation must mirror the click guard above\n              if (modelCanFit || variantHasInstance) {\n                onSelectModel(variant.id);\n              }\n            }\n          }}\n        >\n          <!-- Quantization badge -->\n          <span\n            class=\"text-xs font-mono px-1.5 py-0.5 rounded bg-white/10 text-white/70 flex-shrink-0\"\n          >\n            {variant.quantization || \"default\"}\n          </span>\n\n          <!-- Size -->\n          <span\n            class=\"text-xs font-mono flex-1 {getSizeClassForFitStatus(\n              fitStatus,\n            )}\"\n          >\n            {formatSize(variant.storage_size_megabytes)}\n          </span>\n\n          <!-- Download indicator for this variant -->\n          {#if downloadStatusMap?.get(variant.id)}\n            {@const variantDl = downloadStatusMap.get(variant.id)}\n            {#if variantDl}\n              <span\n                class=\"flex-shrink-0\"\n                title={`Downloaded on ${variantDl.nodeNames.join(\", \")}`}\n              >\n                <svg\n                  class=\"w-3.5 h-3.5\"\n                  viewBox=\"0 0 24 24\"\n                  fill=\"none\"\n                  stroke=\"currentColor\"\n                  stroke-width=\"2\"\n                  stroke-linecap=\"round\"\n                  stroke-linejoin=\"round\"\n                >\n                  <path\n                    class=\"text-white/40\"\n                    d=\"M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z\"\n                  />\n                  <path class=\"text-green-400\" d=\"m9 13 2 2 4-4\" />\n                </svg>\n              </span>\n            {/if}\n          {/if}\n\n          <!-- Instance status badge -->\n          {#if instanceStatuses[variant.id]}\n            {@const instStatus = instanceStatuses[variant.id]}\n            {#if instStatus.status === \"READY\" || instStatus.status === \"LOADED\" || instStatus.status === \"RUNNING\"}\n              <span class=\"flex-shrink-0\" title=\"Running\">\n                <svg\n                  class=\"w-3 h-3 text-green-400\"\n                  viewBox=\"0 0 12 12\"\n                  fill=\"currentColor\"\n                >\n                  <circle cx=\"6\" cy=\"6\" r=\"5\" />\n                </svg>\n              </span>\n            {:else if instStatus.status === \"DOWNLOADING\"}\n              <span class=\"flex-shrink-0 animate-pulse\" title=\"Downloading\">\n                <svg\n                  class=\"w-3.5 h-3.5 text-blue-400\"\n                  viewBox=\"0 0 24 24\"\n                  fill=\"none\"\n                  stroke=\"currentColor\"\n                  stroke-width=\"2\"\n                  stroke-linecap=\"round\"\n                  stroke-linejoin=\"round\"\n                >\n                  <path d=\"M21 15v4a2 2 0 01-2 2H5a2 2 0 01-2-2v-4\" />\n                  <polyline points=\"7 10 12 15 17 10\" />\n                  <line x1=\"12\" y1=\"15\" x2=\"12\" y2=\"3\" />\n                </svg>\n              </span>\n            {:else if instStatus.status === \"LOADING\" || 
instStatus.status === \"WARMING UP\"}\n              <span class=\"flex-shrink-0 animate-pulse\" title=\"Loading\">\n                <svg\n                  class=\"w-3 h-3 text-yellow-400\"\n                  viewBox=\"0 0 12 12\"\n                  fill=\"currentColor\"\n                >\n                  <circle cx=\"6\" cy=\"6\" r=\"5\" />\n                </svg>\n              </span>\n            {/if}\n          {/if}\n\n          <!-- Check mark if selected -->\n          {#if isSelected}\n            <svg\n              class=\"w-4 h-4 text-exo-yellow\"\n              viewBox=\"0 0 24 24\"\n              fill=\"currentColor\"\n            >\n              <path\n                d=\"M9 16.17L4.83 12l-1.42 1.41L9 19 21 7l-1.41-1.41L9 16.17z\"\n              />\n            </svg>\n          {/if}\n\n          <!-- Info button -->\n          <button\n            type=\"button\"\n            class=\"p-1 rounded hover:bg-white/10 transition-colors flex-shrink-0\"\n            onclick={(e) => {\n              e.stopPropagation();\n              onShowInfo({\n                id: variant.id,\n                name: variant.name || variant.id,\n                capabilities: group.capabilities,\n                family: group.family,\n                variants: [variant],\n                smallestVariant: variant,\n                hasMultipleVariants: false,\n              });\n            }}\n            title=\"View variant details\"\n          >\n            <svg\n              class=\"w-4 h-4 text-white/30 hover:text-white/50\"\n              viewBox=\"0 0 24 24\"\n              fill=\"currentColor\"\n            >\n              <path\n                d=\"M12 2C6.48 2 2 6.48 2 12s4.48 10 10 10 10-4.48 10-10S17.52 2 12 2zm1 15h-2v-6h2v6zm0-8h-2V7h2v2z\"\n              />\n            </svg>\n          </button>\n        </div>\n      {/each}\n    </div>\n  {/if}\n</div>\n\n<style>\n  .model-just-added {\n    animation: highlightFade 4s ease-out forwards;\n  }\n\n  @keyframes highlightFade {\n    0%,\n    40% {\n      background-color: rgba(20, 83, 45, 0.25);\n      box-shadow: inset 0 0 0 1px rgba(74, 222, 128, 0.4);\n    }\n    100% {\n      background-color: transparent;\n      box-shadow: none;\n    }\n  }\n</style>\n"
  },
  {
    "path": "dashboard/src/lib/components/ModelPickerModal.svelte",
    "content": "<script lang=\"ts\">\n  import { tick } from \"svelte\";\n  import { fade, fly } from \"svelte/transition\";\n  import { cubicOut } from \"svelte/easing\";\n  import FamilySidebar from \"./FamilySidebar.svelte\";\n  import ModelPickerGroup from \"./ModelPickerGroup.svelte\";\n  import ModelFilterPopover from \"./ModelFilterPopover.svelte\";\n  import HuggingFaceResultItem from \"./HuggingFaceResultItem.svelte\";\n  import { getNodesWithModelDownloaded } from \"$lib/utils/downloads\";\n  import { getRecentEntries } from \"$lib/stores/recents.svelte\";\n  import { addToast } from \"$lib/stores/toast.svelte\";\n\n  interface ModelInfo {\n    id: string;\n    name?: string;\n    storage_size_megabytes?: number;\n    base_model?: string;\n    quantization?: string;\n    supports_tensor?: boolean;\n    capabilities?: string[];\n    family?: string;\n    is_custom?: boolean;\n    tasks?: string[];\n    hugging_face_id?: string;\n  }\n\n  interface ModelGroup {\n    id: string;\n    name: string;\n    capabilities: string[];\n    family: string;\n    variants: ModelInfo[];\n    smallestVariant: ModelInfo;\n    hasMultipleVariants: boolean;\n  }\n\n  interface FilterState {\n    capabilities: string[];\n    sizeRange: { min: number; max: number } | null;\n    downloadedOnly: boolean;\n    readyOnly: boolean;\n  }\n\n  interface HuggingFaceModel {\n    id: string;\n    author: string;\n    downloads: number;\n    likes: number;\n    last_modified: string;\n    tags: string[];\n  }\n\n  type ModelFitStatus = \"fits_now\" | \"fits_cluster_capacity\" | \"too_large\";\n\n  export type InstanceStatus = {\n    status: string;\n    statusClass: string;\n  };\n\n  type ModelPickerModalProps = {\n    isOpen: boolean;\n    models: ModelInfo[];\n    selectedModelId: string | null;\n    favorites: Set<string>;\n    recentModelIds?: string[];\n    hasRecents?: boolean;\n    existingModelIds: Set<string>;\n    canModelFit: (modelId: string) => boolean;\n    getModelFitStatus: (modelId: string) => ModelFitStatus;\n    onSelect: (modelId: string) => void;\n    onClose: () => void;\n    onToggleFavorite: (baseModelId: string) => void;\n    onAddModel: (modelId: string) => Promise<void>;\n    onDeleteModel: (modelId: string) => Promise<void>;\n    totalMemoryGB: number;\n    usedMemoryGB: number;\n    downloadsData?: Record<string, unknown[]>;\n    topologyNodes?: Record<\n      string,\n      {\n        friendly_name?: string;\n        system_info?: { model_id?: string };\n        macmon_info?: { memory?: { ram_total?: number } };\n      }\n    >;\n    instanceStatuses?: Record<string, InstanceStatus>;\n  };\n\n  let {\n    isOpen,\n    models,\n    selectedModelId,\n    favorites,\n    recentModelIds = [],\n    hasRecents: hasRecentsTab = false,\n    existingModelIds,\n    canModelFit,\n    getModelFitStatus,\n    onSelect,\n    onClose,\n    onToggleFavorite,\n    onAddModel,\n    onDeleteModel,\n    totalMemoryGB,\n    usedMemoryGB,\n    downloadsData,\n    topologyNodes,\n    instanceStatuses = {},\n  }: ModelPickerModalProps = $props();\n\n  // Local state\n  let searchQuery = $state(\"\");\n  let selectedFamily = $state<string | null>(null);\n  let expandedGroups = $state<Set<string>>(new Set());\n  let showFilters = $state(false);\n  let filters = $state<FilterState>({\n    capabilities: [],\n    sizeRange: null,\n    downloadedOnly: false,\n    readyOnly: false,\n  });\n  let infoGroup = $state<ModelGroup | null>(null);\n\n  // Download availability per model group\n  type 
DownloadAvailability = {\n    available: boolean;\n    nodeNames: string[];\n    nodeIds: string[];\n  };\n\n  function getNodeName(nodeId: string): string {\n    const node = topologyNodes?.[nodeId];\n    return (\n      node?.friendly_name || node?.system_info?.model_id || nodeId.slice(0, 8)\n    );\n  }\n\n  const modelDownloadAvailability = $derived.by(() => {\n    const result = new Map<string, DownloadAvailability>();\n    if (!downloadsData || !topologyNodes) return result;\n\n    for (const model of models) {\n      const nodeIds = getNodesWithModelDownloaded(downloadsData, model.id);\n      if (nodeIds.length === 0) continue;\n\n      // Sum total RAM across nodes that have the model\n      let totalRamBytes = 0;\n      for (const nodeId of nodeIds) {\n        const ramTotal = topologyNodes[nodeId]?.macmon_info?.memory?.ram_total;\n        if (typeof ramTotal === \"number\") totalRamBytes += ramTotal;\n      }\n\n      const modelSizeBytes = (model.storage_size_megabytes || 0) * 1024 * 1024;\n      result.set(model.id, {\n        available: modelSizeBytes > 0 && totalRamBytes >= modelSizeBytes,\n        nodeNames: nodeIds.map(getNodeName),\n        nodeIds,\n      });\n    }\n    return result;\n  });\n\n  // Aggregate download availability per group (available if ANY variant is available)\n  function getGroupDownloadAvailability(\n    group: ModelGroup,\n  ): DownloadAvailability | undefined {\n    for (const variant of group.variants) {\n      const avail = modelDownloadAvailability.get(variant.id);\n      if (avail && avail.nodeIds.length > 0) return avail;\n    }\n    return undefined;\n  }\n\n  // Get per-variant download map for a group\n  function getVariantDownloadMap(\n    group: ModelGroup,\n  ): Map<string, DownloadAvailability> {\n    const map = new Map<string, DownloadAvailability>();\n    for (const variant of group.variants) {\n      const avail = modelDownloadAvailability.get(variant.id);\n      if (avail && avail.nodeIds.length > 0) map.set(variant.id, avail);\n    }\n    return map;\n  }\n\n  // HuggingFace Hub state\n  let hfSearchQuery = $state(\"\");\n  let hfSearchResults = $state<HuggingFaceModel[]>([]);\n  let hfTrendingModels = $state<HuggingFaceModel[]>([]);\n  let hfIsSearching = $state(false);\n  let hfIsLoadingTrending = $state(false);\n  let addingModelId = $state<string | null>(null);\n  let hfSearchDebounceTimer: ReturnType<typeof setTimeout> | null = null;\n  let manualModelId = $state(\"\");\n  let addModelError = $state<string | null>(null);\n  let justAddedModelId = $state<string | null>(null);\n  let justAddedTimer: ReturnType<typeof setTimeout> | null = null;\n\n  // Inline HuggingFace search in main search bar\n  let mainSearchHfResults = $state<HuggingFaceModel[]>([]);\n  let mainSearchHfLoading = $state(false);\n  let mainSearchDebounceTimer: ReturnType<typeof setTimeout> | null = null;\n\n  // Reset transient state when modal opens, but preserve tab selection\n  $effect(() => {\n    if (isOpen) {\n      searchQuery = \"\";\n      expandedGroups = new Set();\n      showFilters = false;\n      manualModelId = \"\";\n      addModelError = null;\n      justAddedModelId = null;\n      if (justAddedTimer) {\n        clearTimeout(justAddedTimer);\n        justAddedTimer = null;\n      }\n    }\n  });\n\n  // Fetch trending models when HuggingFace is selected\n  $effect(() => {\n    if (\n      selectedFamily === \"huggingface\" &&\n      hfTrendingModels.length === 0 &&\n      !hfIsLoadingTrending\n    ) {\n      fetchTrendingModels();\n    }\n  
});\n\n  // Inline HuggingFace search when local search returns no results\n  $effect(() => {\n    const query = searchQuery.trim();\n    const noLocalResults = filteredGroups.length === 0;\n\n    if (mainSearchDebounceTimer) {\n      clearTimeout(mainSearchDebounceTimer);\n      mainSearchDebounceTimer = null;\n    }\n\n    if (\n      selectedFamily === \"huggingface\" ||\n      selectedFamily === \"recents\" ||\n      selectedFamily === \"favorites\" ||\n      query.length < 2 ||\n      !noLocalResults\n    ) {\n      mainSearchHfResults = [];\n      mainSearchHfLoading = false;\n      return;\n    }\n\n    mainSearchHfLoading = true;\n    mainSearchDebounceTimer = setTimeout(async () => {\n      try {\n        const response = await fetch(\n          `/models/search?query=${encodeURIComponent(query)}&limit=10`,\n        );\n        if (response.ok) {\n          const results: HuggingFaceModel[] = await response.json();\n          mainSearchHfResults = results.filter(\n            (r) => !existingModelIds.has(r.id),\n          );\n        } else {\n          mainSearchHfResults = [];\n        }\n      } catch {\n        mainSearchHfResults = [];\n      } finally {\n        mainSearchHfLoading = false;\n      }\n    }, 500);\n  });\n\n  async function fetchTrendingModels() {\n    hfIsLoadingTrending = true;\n    try {\n      const response = await fetch(\"/models/search?query=&limit=20\");\n      if (response.ok) {\n        hfTrendingModels = await response.json();\n      }\n    } catch (error) {\n      console.error(\"Failed to fetch trending models:\", error);\n    } finally {\n      hfIsLoadingTrending = false;\n    }\n  }\n\n  async function searchHuggingFace(query: string) {\n    if (query.length < 2) {\n      hfSearchResults = [];\n      return;\n    }\n\n    hfIsSearching = true;\n    try {\n      const response = await fetch(\n        `/models/search?query=${encodeURIComponent(query)}&limit=20`,\n      );\n      if (response.ok) {\n        hfSearchResults = await response.json();\n      } else {\n        hfSearchResults = [];\n      }\n    } catch (error) {\n      console.error(\"Failed to search models:\", error);\n      hfSearchResults = [];\n    } finally {\n      hfIsSearching = false;\n    }\n  }\n\n  function handleHfSearchInput(query: string) {\n    hfSearchQuery = query;\n    addModelError = null;\n\n    if (hfSearchDebounceTimer) {\n      clearTimeout(hfSearchDebounceTimer);\n    }\n\n    if (query.length >= 2) {\n      hfSearchDebounceTimer = setTimeout(() => {\n        searchHuggingFace(query);\n      }, 300);\n    } else {\n      hfSearchResults = [];\n    }\n  }\n\n  async function handleAddModel(modelId: string) {\n    addingModelId = modelId;\n    addModelError = null;\n    try {\n      await onAddModel(modelId);\n      // Success: show toast, switch to All Models, highlight the model\n      const shortName = modelId.split(\"/\").pop() || modelId;\n      addToast({ type: \"success\", message: `Added ${shortName}` });\n      justAddedModelId = modelId;\n      selectedFamily = null;\n      searchQuery = \"\";\n      // Scroll to the newly added model after DOM update\n      await tick();\n      const el = document.querySelector(\n        `[data-model-ids~=\"${CSS.escape(modelId)}\"]`,\n      );\n      el?.scrollIntoView({ behavior: \"smooth\", block: \"center\" });\n      // Clear highlight after 4 seconds\n      if (justAddedTimer) clearTimeout(justAddedTimer);\n      justAddedTimer = setTimeout(() => {\n        justAddedModelId = null;\n        justAddedTimer = 
null;\n      }, 4000);\n    } catch (error) {\n      addModelError =\n        error instanceof Error ? error.message : \"Failed to add model\";\n    } finally {\n      addingModelId = null;\n    }\n  }\n\n  async function handleAddManualModel() {\n    if (!manualModelId.trim()) return;\n    await handleAddModel(manualModelId.trim());\n    if (!addModelError) {\n      manualModelId = \"\";\n    }\n  }\n\n  function handleSelectHfModel(modelId: string) {\n    onSelect(modelId);\n    onClose();\n  }\n\n  // Models to display in HuggingFace view\n  const hfDisplayModels = $derived.by((): HuggingFaceModel[] => {\n    if (hfSearchQuery.length >= 2) {\n      return hfSearchResults;\n    }\n    return hfTrendingModels;\n  });\n\n  // Group models by base_model\n  const groupedModels = $derived.by((): ModelGroup[] => {\n    const groups = new Map<string, ModelGroup>();\n\n    for (const model of models) {\n      const groupId = model.base_model || model.id;\n      const groupName = model.base_model || model.name || model.id;\n\n      if (!groups.has(groupId)) {\n        groups.set(groupId, {\n          id: groupId,\n          name: groupName,\n          capabilities: model.capabilities || [\"text\"],\n          family: model.family || \"\",\n          variants: [],\n          smallestVariant: model,\n          hasMultipleVariants: false,\n        });\n      }\n\n      const group = groups.get(groupId)!;\n      group.variants.push(model);\n\n      // Track smallest variant\n      if (\n        (model.storage_size_megabytes || 0) <\n        (group.smallestVariant.storage_size_megabytes || Infinity)\n      ) {\n        group.smallestVariant = model;\n      }\n\n      // Update capabilities if not set\n      if (\n        group.capabilities.length <= 1 &&\n        model.capabilities &&\n        model.capabilities.length > 1\n      ) {\n        group.capabilities = model.capabilities;\n      }\n      if (!group.family && model.family) {\n        group.family = model.family;\n      }\n    }\n\n    // Sort variants within each group by size\n    for (const group of groups.values()) {\n      group.variants.sort(\n        (a, b) =>\n          (a.storage_size_megabytes || 0) - (b.storage_size_megabytes || 0),\n      );\n      group.hasMultipleVariants = group.variants.length > 1;\n    }\n\n    // Convert to array and sort by smallest variant size (biggest first)\n    return Array.from(groups.values()).sort((a, b) => {\n      return (\n        (b.smallestVariant.storage_size_megabytes || 0) -\n        (a.smallestVariant.storage_size_megabytes || 0)\n      );\n    });\n  });\n\n  // Get unique families\n  const uniqueFamilies = $derived.by((): string[] => {\n    const families = new Set<string>();\n    for (const group of groupedModels) {\n      if (group.family) {\n        families.add(group.family);\n      }\n    }\n    const familyOrder = [\n      \"kimi\",\n      \"qwen\",\n      \"glm\",\n      \"minimax\",\n      \"deepseek\",\n      \"gpt-oss\",\n      \"llama\",\n      \"flux\",\n      \"qwen-image\",\n    ];\n    return Array.from(families).sort((a, b) => {\n      const aIdx = familyOrder.indexOf(a);\n      const bIdx = familyOrder.indexOf(b);\n      if (aIdx === -1 && bIdx === -1) return a.localeCompare(b);\n      if (aIdx === -1) return 1;\n      if (bIdx === -1) return -1;\n      return aIdx - bIdx;\n    });\n  });\n\n  // Filter models based on search, family, and filters\n  const filteredGroups = $derived.by((): ModelGroup[] => {\n    let result: ModelGroup[] = [...groupedModels];\n\n    // 
Filter by family\n    if (selectedFamily === \"favorites\") {\n      result = result.filter((g) => favorites.has(g.id));\n    } else if (\n      selectedFamily &&\n      selectedFamily !== \"huggingface\" &&\n      selectedFamily !== \"recents\"\n    ) {\n      result = result.filter((g) => g.family === selectedFamily);\n    }\n\n    // Filter by search query\n    if (searchQuery.trim()) {\n      const query = searchQuery.toLowerCase().trim();\n      result = result.filter(\n        (g) =>\n          g.name.toLowerCase().includes(query) ||\n          g.variants.some(\n            (v) =>\n              v.id.toLowerCase().includes(query) ||\n              (v.name || \"\").toLowerCase().includes(query),\n          ),\n      );\n    }\n\n    // Filter by capabilities\n    if (filters.capabilities.length > 0) {\n      result = result.filter((g) =>\n        filters.capabilities.every((cap) => g.capabilities.includes(cap)),\n      );\n    }\n\n    // Filter by size range\n    if (filters.sizeRange) {\n      const { min, max } = filters.sizeRange;\n      result = result.filter((g) => {\n        const sizeGB = (g.smallestVariant.storage_size_megabytes || 0) / 1024;\n        return sizeGB >= min && sizeGB <= max;\n      });\n    }\n\n    // Filter to downloaded models only\n    if (filters.downloadedOnly) {\n      result = result.filter((g) =>\n        g.variants.some((v) => {\n          const avail = modelDownloadAvailability.get(v.id);\n          return avail && avail.nodeIds.length > 0;\n        }),\n      );\n    }\n\n    // Filter to ready/running models only\n    if (filters.readyOnly) {\n      result = result.filter((g) =>\n        g.variants.some((v) => {\n          const s = instanceStatuses[v.id];\n          return s && s.statusClass === \"ready\";\n        }),\n      );\n    }\n\n    // Sort: fits-now first, then fits-cluster-capacity, then too-large\n    result.sort((a, b) => {\n      const getGroupFitRank = (group: ModelGroup): number => {\n        let hasClusterCapacityOnly = false;\n        for (const variant of group.variants) {\n          const fitStatus = getModelFitStatus(variant.id);\n          if (fitStatus === \"fits_now\") return 0;\n          if (fitStatus === \"fits_cluster_capacity\") {\n            hasClusterCapacityOnly = true;\n          }\n        }\n        return hasClusterCapacityOnly ? 
1 : 2;\n      };\n\n      const aRank = getGroupFitRank(a);\n      const bRank = getGroupFitRank(b);\n      if (aRank !== bRank) return aRank - bRank;\n\n      return (\n        (b.smallestVariant.storage_size_megabytes || 0) -\n        (a.smallestVariant.storage_size_megabytes || 0)\n      );\n    });\n\n    return result;\n  });\n\n  // Check if any favorites exist\n  const hasFavorites = $derived(favorites.size > 0);\n\n  // Timestamp lookup for recent models\n  const recentTimestamps = $derived(\n    new Map(getRecentEntries().map((e) => [e.modelId, e.launchedAt])),\n  );\n\n  // Recent models: single-variant ModelGroups in launch order\n  const recentGroups = $derived.by((): ModelGroup[] => {\n    if (!recentModelIds || recentModelIds.length === 0) return [];\n    const result: ModelGroup[] = [];\n    for (const id of recentModelIds) {\n      const model = models.find((m) => m.id === id);\n      if (model) {\n        result.push({\n          id: model.base_model || model.id,\n          name: model.name || model.id,\n          capabilities: model.capabilities || [\"text\"],\n          family: model.family || \"\",\n          variants: [model],\n          smallestVariant: model,\n          hasMultipleVariants: false,\n        });\n      }\n    }\n    return result;\n  });\n\n  // Filtered recent groups (apply search query)\n  const filteredRecentGroups = $derived.by((): ModelGroup[] => {\n    if (!searchQuery.trim()) return recentGroups;\n    const query = searchQuery.toLowerCase().trim();\n    return recentGroups.filter(\n      (g) =>\n        g.name.toLowerCase().includes(query) ||\n        g.variants.some(\n          (v) =>\n            v.id.toLowerCase().includes(query) ||\n            (v.name || \"\").toLowerCase().includes(query) ||\n            (v.quantization || \"\").toLowerCase().includes(query),\n        ),\n    );\n  });\n\n  // Split filtered groups into recommended (fits_now) and others for visual separation\n  const recommendedGroups = $derived(\n    filteredGroups.filter((g) =>\n      g.variants.some((v) => getModelFitStatus(v.id) === \"fits_now\"),\n    ),\n  );\n  const otherGroups = $derived(\n    filteredGroups.filter(\n      (g) => !g.variants.some((v) => getModelFitStatus(v.id) === \"fits_now\"),\n    ),\n  );\n\n  function toggleGroupExpanded(groupId: string) {\n    const next = new Set(expandedGroups);\n    if (next.has(groupId)) {\n      next.delete(groupId);\n    } else {\n      next.add(groupId);\n    }\n    expandedGroups = next;\n  }\n\n  function handleSelect(modelId: string) {\n    onSelect(modelId);\n    onClose();\n  }\n\n  function handleKeydown(e: KeyboardEvent) {\n    if (e.key === \"Escape\") {\n      onClose();\n    }\n  }\n\n  function handleFiltersChange(newFilters: FilterState) {\n    filters = newFilters;\n  }\n\n  function clearFilters() {\n    filters = {\n      capabilities: [],\n      sizeRange: null,\n      downloadedOnly: false,\n      readyOnly: false,\n    };\n  }\n\n  const hasActiveFilters = $derived(\n    filters.capabilities.length > 0 ||\n      filters.sizeRange !== null ||\n      filters.downloadedOnly ||\n      filters.readyOnly,\n  );\n</script>\n\n<svelte:window onkeydown={handleKeydown} />\n\n{#if isOpen}\n  <!-- Backdrop -->\n  <div\n    class=\"fixed inset-0 z-50 bg-black/80 backdrop-blur-sm\"\n    transition:fade={{ duration: 200 }}\n    onclick={onClose}\n    role=\"presentation\"\n  ></div>\n\n  <!-- Modal -->\n  <div\n    class=\"fixed z-50 top-1/2 left-1/2 -translate-x-1/2 -translate-y-1/2 w-[min(90vw,600px)] 
h-[min(80vh,700px)] bg-exo-dark-gray border border-exo-yellow/10 rounded-lg shadow-2xl overflow-hidden flex flex-col\"\n    transition:fly={{ y: 20, duration: 300, easing: cubicOut }}\n    role=\"dialog\"\n    aria-modal=\"true\"\n    aria-label=\"Select a model\"\n  >\n    <!-- Header with search -->\n    <div\n      class=\"flex items-center gap-2 p-3 border-b border-exo-yellow/10 bg-exo-medium-gray/30\"\n    >\n      {#if selectedFamily === \"huggingface\"}\n        <!-- HuggingFace search -->\n        <svg\n          class=\"w-5 h-5 text-orange-400/60 flex-shrink-0\"\n          viewBox=\"0 0 24 24\"\n          fill=\"none\"\n          stroke=\"currentColor\"\n          stroke-width=\"2\"\n        >\n          <circle cx=\"11\" cy=\"11\" r=\"8\" />\n          <path d=\"M21 21l-4.35-4.35\" />\n        </svg>\n        <input\n          type=\"search\"\n          class=\"flex-1 bg-transparent border-none outline-none text-sm font-mono text-white placeholder-white/40\"\n          placeholder=\"Search mlx-community models...\"\n          value={hfSearchQuery}\n          oninput={(e) => handleHfSearchInput(e.currentTarget.value)}\n        />\n        {#if hfIsSearching}\n          <div class=\"flex-shrink-0\">\n            <span\n              class=\"w-4 h-4 border-2 border-orange-400 border-t-transparent rounded-full animate-spin block\"\n            ></span>\n          </div>\n        {/if}\n      {:else}\n        <!-- Normal model search -->\n        <svg\n          class=\"w-5 h-5 text-white/40 flex-shrink-0\"\n          viewBox=\"0 0 24 24\"\n          fill=\"none\"\n          stroke=\"currentColor\"\n          stroke-width=\"2\"\n        >\n          <circle cx=\"11\" cy=\"11\" r=\"8\" />\n          <path d=\"M21 21l-4.35-4.35\" />\n        </svg>\n        <input\n          type=\"search\"\n          class=\"flex-1 bg-transparent border-none outline-none text-sm font-mono text-white placeholder-white/40\"\n          placeholder=\"Search models...\"\n          bind:value={searchQuery}\n        />\n        <!-- Cluster memory -->\n        <span\n          class=\"text-xs font-mono flex-shrink-0\"\n          title=\"Cluster memory usage\"\n          ><span class=\"text-exo-yellow\">{Math.round(usedMemoryGB)}GB</span\n          ><span class=\"text-white/40\">/{Math.round(totalMemoryGB)}GB</span\n          ></span\n        >\n        <!-- Filter button -->\n        <div class=\"relative filter-toggle\">\n          <button\n            type=\"button\"\n            class=\"p-1.5 rounded hover:bg-white/10 transition-colors {hasActiveFilters\n              ? 
'text-exo-yellow'\n              : 'text-white/50'}\"\n            onclick={() => (showFilters = !showFilters)}\n            title=\"Filter by capability or size\"\n          >\n            <svg class=\"w-5 h-5\" viewBox=\"0 0 24 24\" fill=\"currentColor\">\n              <path d=\"M10 18h4v-2h-4v2zM3 6v2h18V6H3zm3 7h12v-2H6v2z\" />\n            </svg>\n          </button>\n          {#if showFilters}\n            <ModelFilterPopover\n              {filters}\n              onChange={handleFiltersChange}\n              onClear={clearFilters}\n              onClose={() => (showFilters = false)}\n            />\n          {/if}\n        </div>\n      {/if}\n      <!-- Close button -->\n      <button\n        type=\"button\"\n        class=\"p-1.5 rounded hover:bg-white/10 transition-colors text-white/50 hover:text-white/70\"\n        onclick={onClose}\n        title=\"Close model picker\"\n      >\n        <svg class=\"w-5 h-5\" viewBox=\"0 0 24 24\" fill=\"currentColor\">\n          <path\n            d=\"M19 6.41L17.59 5 12 10.59 6.41 5 5 6.41 10.59 12 5 17.59 6.41 19 12 13.41 17.59 19 19 17.59 13.41 12 19 6.41z\"\n          />\n        </svg>\n      </button>\n    </div>\n\n    <!-- Body -->\n    <div class=\"flex flex-1 overflow-hidden\">\n      <!-- Family sidebar -->\n      <FamilySidebar\n        families={uniqueFamilies}\n        {selectedFamily}\n        {hasFavorites}\n        hasRecents={hasRecentsTab}\n        onSelect={(family) => (selectedFamily = family)}\n      />\n\n      <!-- Model list -->\n      <div class=\"flex-1 overflow-y-auto scrollbar-hide flex flex-col\">\n        {#if selectedFamily === \"huggingface\"}\n          <!-- HuggingFace Hub view -->\n          <div class=\"flex-1 flex flex-col min-h-0\">\n            <!-- Section header -->\n            <div\n              class=\"sticky top-0 z-10 px-3 py-2 bg-exo-dark-gray/95 border-b border-exo-yellow/10\"\n            >\n              <span class=\"text-xs font-mono text-white/40\">\n                {#if hfSearchQuery.length >= 2}\n                  Search results for \"{hfSearchQuery}\"\n                {:else}\n                  Trending on mlx-community\n                {/if}\n              </span>\n            </div>\n\n            <!-- Results list -->\n            <div class=\"flex-1 overflow-y-auto scrollbar-hide\">\n              {#if hfIsLoadingTrending && hfTrendingModels.length === 0}\n                <div\n                  class=\"flex items-center justify-center py-12 text-white/40\"\n                >\n                  <span\n                    class=\"w-5 h-5 border-2 border-orange-400 border-t-transparent rounded-full animate-spin mr-2\"\n                  ></span>\n                  <span class=\"font-mono text-sm\"\n                    >Loading trending models...</span\n                  >\n                </div>\n              {:else if hfDisplayModels.length === 0}\n                <div\n                  class=\"flex flex-col items-center justify-center py-12 text-white/40\"\n                >\n                  <svg\n                    class=\"w-10 h-10 mb-2\"\n                    viewBox=\"0 0 24 24\"\n                    fill=\"currentColor\"\n                  >\n                    <path\n                      d=\"M12 2C6.48 2 2 6.48 2 12s4.48 10 10 10 10-4.48 10-10S17.52 2 12 2zm-2 13.5c-.83 0-1.5-.67-1.5-1.5s.67-1.5 1.5-1.5 1.5.67 1.5 1.5-.67 1.5-1.5 1.5zm4 0c-.83 0-1.5-.67-1.5-1.5s.67-1.5 1.5-1.5 1.5.67 1.5 1.5-.67 1.5-1.5 1.5zm2-4.5H8c0-2.21 1.79-4 4-4s4 1.79 4 4z\"\n                 
   />\n                  </svg>\n                  <p class=\"font-mono text-sm\">No models found</p>\n                  {#if hfSearchQuery}\n                    <p class=\"font-mono text-xs mt-1\">\n                      Try a different search term\n                    </p>\n                  {/if}\n                </div>\n              {:else}\n                {#each hfDisplayModels as model}\n                  <HuggingFaceResultItem\n                    {model}\n                    isAdded={existingModelIds.has(model.id)}\n                    isAdding={addingModelId === model.id}\n                    onAdd={() => handleAddModel(model.id)}\n                    onSelect={() => handleSelectHfModel(model.id)}\n                    downloadedOnNodes={downloadsData\n                      ? getNodesWithModelDownloaded(\n                          downloadsData,\n                          model.id,\n                        ).map(getNodeName)\n                      : []}\n                  />\n                {/each}\n              {/if}\n            </div>\n\n            <!-- Manual input footer -->\n            <div\n              class=\"sticky bottom-0 border-t border-exo-yellow/10 bg-exo-dark-gray p-3\"\n            >\n              {#if addModelError}\n                <div\n                  class=\"bg-red-500/10 border border-red-500/30 rounded px-3 py-2 mb-2\"\n                >\n                  <p class=\"text-red-400 text-xs font-mono break-words\">\n                    {addModelError}\n                  </p>\n                </div>\n              {/if}\n              <div class=\"flex gap-2\">\n                <input\n                  type=\"text\"\n                  class=\"flex-1 bg-exo-black/60 border border-exo-yellow/30 rounded px-3 py-1.5 text-xs font-mono text-white placeholder-white/30 focus:outline-none focus:border-exo-yellow/50\"\n                  placeholder=\"Or paste model ID directly...\"\n                  bind:value={manualModelId}\n                  onkeydown={(e) => {\n                    if (e.key === \"Enter\") handleAddManualModel();\n                  }}\n                />\n                <button\n                  type=\"button\"\n                  onclick={handleAddManualModel}\n                  disabled={!manualModelId.trim() || addingModelId !== null}\n                  class=\"px-3 py-1.5 text-xs font-mono tracking-wider uppercase bg-orange-500/10 text-orange-400 border border-orange-400/30 hover:bg-orange-500/20 transition-colors rounded disabled:opacity-50 disabled:cursor-not-allowed\"\n                >\n                  Add\n                </button>\n              </div>\n            </div>\n          </div>\n        {:else if selectedFamily === \"recents\"}\n          <!-- Recent models view -->\n          {#if filteredRecentGroups.length === 0}\n            <div\n              class=\"flex flex-col items-center justify-center h-full text-white/40 p-8\"\n            >\n              <svg\n                class=\"w-12 h-12 mb-3\"\n                viewBox=\"0 0 24 24\"\n                fill=\"currentColor\"\n              >\n                <path\n                  d=\"M13 3a9 9 0 0 0-9 9H1l3.89 3.89.07.14L9 12H6c0-3.87 3.13-7 7-7s7 3.13 7 7-3.13 7-7 7c-1.93 0-3.68-.79-4.94-2.06l-1.42 1.42A8.954 8.954 0 0 0 13 21a9 9 0 0 0 0-18zm-1 5v5l4.28 2.54.72-1.21-3.5-2.08V8H12z\"\n                />\n              </svg>\n              <p class=\"font-mono text-sm\">\n                {searchQuery\n                  ? 
\"No matching recent models\"\n                  : \"No recently launched models\"}\n              </p>\n            </div>\n          {:else}\n            {#each filteredRecentGroups as group}\n              <ModelPickerGroup\n                {group}\n                isExpanded={expandedGroups.has(group.id)}\n                isFavorite={favorites.has(group.id)}\n                isHighlighted={justAddedModelId !== null &&\n                  group.variants.some((v) => v.id === justAddedModelId)}\n                {selectedModelId}\n                {canModelFit}\n                {getModelFitStatus}\n                onToggleExpand={() => toggleGroupExpanded(group.id)}\n                onSelectModel={handleSelect}\n                {onToggleFavorite}\n                onShowInfo={(g) => (infoGroup = g)}\n                downloadStatusMap={getVariantDownloadMap(group)}\n                launchedAt={recentTimestamps.get(group.variants[0]?.id ?? \"\")}\n                {instanceStatuses}\n              />\n            {/each}\n          {/if}\n        {:else if filteredGroups.length === 0}\n          <div\n            class=\"flex flex-col items-center justify-center h-full text-white/40 p-8\"\n          >\n            <svg class=\"w-12 h-12 mb-3\" viewBox=\"0 0 24 24\" fill=\"currentColor\">\n              <path\n                d=\"M12 2C6.48 2 2 6.48 2 12s4.48 10 10 10 10-4.48 10-10S17.52 2 12 2zm-2 15l-5-5 1.41-1.41L10 14.17l7.59-7.59L19 8l-9 9z\"\n              />\n            </svg>\n            <p class=\"font-mono text-sm\">No models found</p>\n            {#if hasActiveFilters || searchQuery}\n              <button\n                type=\"button\"\n                class=\"mt-2 text-xs text-exo-yellow hover:underline\"\n                onclick={() => {\n                  searchQuery = \"\";\n                  clearFilters();\n                }}\n              >\n                Clear filters\n              </button>\n            {/if}\n          </div>\n        {:else}\n          <!-- Recommended for your cluster -->\n          {#if recommendedGroups.length > 0 && otherGroups.length > 0 && !searchQuery.trim()}\n            <div\n              class=\"sticky top-0 z-10 flex items-center gap-2 px-3 py-2 bg-green-950/60 border-b border-green-500/20 backdrop-blur-sm\"\n            >\n              <svg\n                class=\"w-3.5 h-3.5 text-green-400 flex-shrink-0\"\n                fill=\"none\"\n                viewBox=\"0 0 24 24\"\n                stroke=\"currentColor\"\n                stroke-width=\"2\"\n              >\n                <path\n                  stroke-linecap=\"round\"\n                  stroke-linejoin=\"round\"\n                  d=\"M9 12.75L11.25 15 15 9.75M21 12a9 9 0 11-18 0 9 9 0 0118 0z\"\n                />\n              </svg>\n              <span\n                class=\"text-xs font-mono text-green-400 tracking-wider uppercase\"\n                >Recommended for your cluster</span\n              >\n              <span class=\"text-xs font-mono text-green-400/50\"\n                >— fits in available memory</span\n              >\n            </div>\n          {/if}\n          {#each recommendedGroups as group}\n            <ModelPickerGroup\n              {group}\n              isExpanded={expandedGroups.has(group.id)}\n              isFavorite={favorites.has(group.id)}\n              isHighlighted={justAddedModelId !== null &&\n                group.variants.some((v) => v.id === justAddedModelId)}\n              {selectedModelId}\n              
{canModelFit}\n              {getModelFitStatus}\n              onToggleExpand={() => toggleGroupExpanded(group.id)}\n              onSelectModel={handleSelect}\n              {onToggleFavorite}\n              onShowInfo={(g) => (infoGroup = g)}\n              downloadStatusMap={getVariantDownloadMap(group)}\n              {instanceStatuses}\n            />\n          {/each}\n          <!-- Other models -->\n          {#if otherGroups.length > 0 && recommendedGroups.length > 0 && !searchQuery.trim()}\n            <div\n              class=\"sticky top-0 z-10 flex items-center gap-2 px-3 py-2 bg-exo-dark-gray/80 border-y border-exo-medium-gray/20 backdrop-blur-sm\"\n            >\n              <span\n                class=\"text-xs font-mono text-white/40 tracking-wider uppercase\"\n                >Other models</span\n              >\n            </div>\n          {/if}\n          {#each otherGroups as group}\n            <ModelPickerGroup\n              {group}\n              isExpanded={expandedGroups.has(group.id)}\n              isFavorite={favorites.has(group.id)}\n              isHighlighted={justAddedModelId !== null &&\n                group.variants.some((v) => v.id === justAddedModelId)}\n              {selectedModelId}\n              {canModelFit}\n              {getModelFitStatus}\n              onToggleExpand={() => toggleGroupExpanded(group.id)}\n              onSelectModel={handleSelect}\n              {onToggleFavorite}\n              onShowInfo={(g) => (infoGroup = g)}\n              downloadStatusMap={getVariantDownloadMap(group)}\n              {instanceStatuses}\n            />\n          {/each}\n          <!-- Inline HuggingFace search results (shown when no local results match) -->\n          {#if filteredGroups.length === 0 && searchQuery.trim().length >= 2 && selectedFamily !== \"huggingface\" && selectedFamily !== \"recents\" && selectedFamily !== \"favorites\"}\n            {#if mainSearchHfLoading}\n              <div\n                class=\"flex items-center gap-2 px-3 py-2 border-t border-orange-400/20 bg-orange-950/20\"\n              >\n                <span\n                  class=\"w-4 h-4 border-2 border-orange-400 border-t-transparent rounded-full animate-spin\"\n                ></span>\n                <span class=\"text-xs font-mono text-orange-400/60\"\n                  >Searching HuggingFace...</span\n                >\n              </div>\n            {:else if mainSearchHfResults.length > 0}\n              <div\n                class=\"sticky top-0 z-10 flex items-center gap-2 px-3 py-2 bg-orange-950/30 border-y border-orange-400/20 backdrop-blur-sm\"\n              >\n                <span\n                  class=\"text-xs font-mono text-orange-400 tracking-wider uppercase\"\n                  >From HuggingFace</span\n                >\n              </div>\n              {#each mainSearchHfResults as model}\n                <HuggingFaceResultItem\n                  {model}\n                  isAdded={existingModelIds.has(model.id)}\n                  isAdding={addingModelId === model.id}\n                  onAdd={() => handleAddModel(model.id)}\n                  onSelect={() => handleSelectHfModel(model.id)}\n                  downloadedOnNodes={downloadsData\n                    ? 
getNodesWithModelDownloaded(downloadsData, model.id).map(\n                        getNodeName,\n                      )\n                    : []}\n                />\n              {/each}\n              <button\n                type=\"button\"\n                class=\"w-full px-3 py-2 text-xs font-mono text-orange-400/60 hover:text-orange-400 hover:bg-orange-500/10 transition-colors text-center\"\n                onclick={() => {\n                  hfSearchQuery = searchQuery;\n                  searchHuggingFace(searchQuery);\n                  selectedFamily = \"huggingface\";\n                }}\n              >\n                See all results on Hub\n              </button>\n            {/if}\n          {/if}\n        {/if}\n      </div>\n    </div>\n\n    <!-- Footer with active filters indicator -->\n    {#if hasActiveFilters}\n      <div\n        class=\"flex items-center gap-2 px-3 py-2 border-t border-exo-yellow/10 bg-exo-medium-gray/20 text-xs font-mono text-white/50\"\n      >\n        <span>Filters:</span>\n        {#each filters.capabilities as cap}\n          <span class=\"px-1.5 py-0.5 bg-exo-yellow/20 text-exo-yellow rounded\"\n            >{cap}</span\n          >\n        {/each}\n        {#if filters.downloadedOnly}\n          <span class=\"px-1.5 py-0.5 bg-green-500/20 text-green-400 rounded\"\n            >Downloaded</span\n          >\n        {/if}\n        {#if filters.readyOnly}\n          <span class=\"px-1.5 py-0.5 bg-green-500/20 text-green-400 rounded\"\n            >Ready</span\n          >\n        {/if}\n        {#if filters.sizeRange}\n          <span class=\"px-1.5 py-0.5 bg-exo-yellow/20 text-exo-yellow rounded\">\n            {filters.sizeRange.min}GB - {filters.sizeRange.max}GB\n          </span>\n        {/if}\n        <button\n          type=\"button\"\n          class=\"ml-auto text-white/40 hover:text-white/60\"\n          onclick={clearFilters}\n        >\n          Clear all\n        </button>\n      </div>\n    {/if}\n  </div>\n\n  <!-- Info modal -->\n  {#if infoGroup}\n    <div\n      class=\"fixed inset-0 z-[60] bg-black/60\"\n      transition:fade={{ duration: 150 }}\n      onclick={() => (infoGroup = null)}\n      role=\"presentation\"\n    ></div>\n    <div\n      class=\"fixed z-[60] top-1/2 left-1/2 -translate-x-1/2 -translate-y-1/2 w-[min(80vw,400px)] bg-exo-dark-gray border border-exo-yellow/10 rounded-lg shadow-2xl p-4\"\n      transition:fly={{ y: 10, duration: 200, easing: cubicOut }}\n      role=\"dialog\"\n      aria-modal=\"true\"\n    >\n      <div class=\"flex items-start justify-between mb-3\">\n        <h3 class=\"font-mono text-lg text-white\">{infoGroup.name}</h3>\n        <button\n          type=\"button\"\n          class=\"p-1 rounded hover:bg-white/10 transition-colors text-white/50\"\n          onclick={() => (infoGroup = null)}\n          title=\"Close model details\"\n          aria-label=\"Close info dialog\"\n        >\n          <svg class=\"w-4 h-4\" viewBox=\"0 0 24 24\" fill=\"currentColor\">\n            <path\n              d=\"M19 6.41L17.59 5 12 10.59 6.41 5 5 6.41 10.59 12 5 17.59 6.41 19 12 13.41 17.59 19 19 17.59 13.41 12 19 6.41z\"\n            />\n          </svg>\n        </button>\n      </div>\n      <div class=\"space-y-2 text-xs font-mono\">\n        <div class=\"flex items-center gap-2\">\n          <span class=\"text-white/40\">Family:</span>\n          <span class=\"text-white/70\">{infoGroup.family || \"Unknown\"}</span>\n        </div>\n        <div class=\"flex items-center gap-2\">\n     
     <span class=\"text-white/40\">Capabilities:</span>\n          <span class=\"text-white/70\">{infoGroup.capabilities.join(\", \")}</span>\n        </div>\n        <div class=\"flex items-center gap-2\">\n          <span class=\"text-white/40\">Variants:</span>\n          <span class=\"text-white/70\">{infoGroup.variants.length}</span>\n        </div>\n        {#if infoGroup.variants.length > 0}\n          <div class=\"mt-3 pt-3 border-t border-exo-yellow/10\">\n            <span class=\"text-white/40\">Available quantizations:</span>\n            <div class=\"flex flex-wrap gap-1 mt-1\">\n              {#each infoGroup.variants as variant}\n                <span\n                  class=\"px-1.5 py-0.5 bg-white/10 text-white/60 rounded text-[10px]\"\n                >\n                  {variant.quantization || \"default\"} ({Math.round(\n                    (variant.storage_size_megabytes || 0) / 1024,\n                  )}GB)\n                </span>\n              {/each}\n            </div>\n          </div>\n        {/if}\n        {#if getGroupDownloadAvailability(infoGroup)?.nodeNames?.length}\n          {@const infoDownload = getGroupDownloadAvailability(infoGroup)}\n          {#if infoDownload}\n            <div class=\"mt-3 pt-3 border-t border-exo-yellow/10\">\n              <div class=\"flex items-center gap-2 mb-1\">\n                <svg\n                  class=\"w-3.5 h-3.5\"\n                  viewBox=\"0 0 24 24\"\n                  fill=\"none\"\n                  stroke=\"currentColor\"\n                  stroke-width=\"2\"\n                  stroke-linecap=\"round\"\n                  stroke-linejoin=\"round\"\n                >\n                  <path\n                    class=\"text-white/40\"\n                    d=\"M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z\"\n                  />\n                  <path class=\"text-green-400\" d=\"m9 13 2 2 4-4\" />\n                </svg>\n                <span class=\"text-white/40\">Downloaded on:</span>\n              </div>\n              <div class=\"flex flex-wrap gap-1 mt-1\">\n                {#each infoDownload.nodeNames as nodeName}\n                  <span\n                    class=\"px-1.5 py-0.5 bg-green-500/10 text-green-400/80 border border-green-500/20 rounded text-[10px]\"\n                  >\n                    {nodeName}\n                  </span>\n                {/each}\n              </div>\n            </div>\n          {/if}\n        {/if}\n      </div>\n    </div>\n  {/if}\n{/if}\n"
  },
  {
    "path": "dashboard/src/lib/components/PrefillProgressBar.svelte",
    "content": "<script lang=\"ts\">\n  import type { PrefillProgress } from \"$lib/stores/app.svelte\";\n\n  interface Props {\n    progress: PrefillProgress;\n    class?: string;\n  }\n\n  let { progress, class: className = \"\" }: Props = $props();\n\n  const percentage = $derived(\n    progress.total > 0\n      ? Math.round((progress.processed / progress.total) * 100)\n      : 0,\n  );\n\n  const etaText = $derived.by(() => {\n    if (progress.processed <= 0 || progress.total <= 0) return null;\n    const elapsedMs = performance.now() - progress.startedAt;\n    if (elapsedMs < 200) return null; // need a minimum sample window\n    const tokensPerMs = progress.processed / elapsedMs;\n    const remainingTokens = progress.total - progress.processed;\n    const remainingMs = remainingTokens / tokensPerMs;\n    const remainingSec = Math.ceil(remainingMs / 1000);\n    if (remainingSec <= 0) return null;\n    if (remainingSec < 60) return `~${remainingSec}s remaining`;\n    const mins = Math.floor(remainingSec / 60);\n    const secs = remainingSec % 60;\n    return `~${mins}m ${secs}s remaining`;\n  });\n\n  function formatTokenCount(count: number | undefined): string {\n    if (count == null) return \"0\";\n    if (count >= 1000) {\n      return `${(count / 1000).toFixed(1)}k`;\n    }\n    return count.toString();\n  }\n</script>\n\n<div class=\"prefill-progress {className}\">\n  <div\n    class=\"flex items-center justify-between text-xs text-exo-light-gray mb-1\"\n  >\n    <span>Processing prompt</span>\n    <span class=\"font-mono\">\n      {formatTokenCount(progress.processed)} / {formatTokenCount(\n        progress.total,\n      )} tokens\n    </span>\n  </div>\n  <div class=\"h-1.5 bg-exo-black/60 rounded-full overflow-hidden\">\n    <div\n      class=\"h-full bg-exo-yellow rounded-full transition-all duration-150 ease-out\"\n      style=\"width: {percentage}%\"\n    ></div>\n  </div>\n  <div\n    class=\"flex items-center justify-between text-xs text-exo-light-gray/70 mt-0.5 font-mono\"\n  >\n    <span>{etaText ?? \"\"}</span>\n    <span>{percentage}%</span>\n  </div>\n</div>\n\n<style>\n  .prefill-progress {\n    width: 100%;\n  }\n</style>\n"
  },
  {
    "path": "dashboard/src/lib/components/ToastContainer.svelte",
    "content": "<script lang=\"ts\">\n  import { toasts, dismissToast, type Toast } from \"$lib/stores/toast.svelte\";\n  import { fly, fade } from \"svelte/transition\";\n  import { flip } from \"svelte/animate\";\n\n  const items = $derived(toasts());\n\n  const typeStyles: Record<\n    Toast[\"type\"],\n    { border: string; icon: string; iconColor: string }\n  > = {\n    success: {\n      border: \"border-l-green-500\",\n      icon: \"M9 12.75L11.25 15 15 9.75M21 12a9 9 0 11-18 0 9 9 0 0118 0z\",\n      iconColor: \"text-green-400\",\n    },\n    error: {\n      border: \"border-l-red-500\",\n      icon: \"M12 9v3.75m9-.75a9 9 0 11-18 0 9 9 0 0118 0zm-9 3.75h.008v.008H12v-.008z\",\n      iconColor: \"text-red-400\",\n    },\n    warning: {\n      border: \"border-l-yellow-500\",\n      icon: \"M12 9v3.75m-9.303 3.376c-.866 1.5.217 3.374 1.948 3.374h14.71c1.73 0 2.813-1.874 1.948-3.374L13.949 3.378c-.866-1.5-3.032-1.5-3.898 0L2.697 16.126z\",\n      iconColor: \"text-yellow-400\",\n    },\n    info: {\n      border: \"border-l-blue-500\",\n      icon: \"M11.25 11.25l.041-.02a.75.75 0 011.063.852l-.708 2.836a.75.75 0 001.063.853l.041-.021M21 12a9 9 0 11-18 0 9 9 0 0118 0zm-9-3.75h.008v.008H12V8.25z\",\n      iconColor: \"text-blue-400\",\n    },\n  };\n</script>\n\n{#if items.length > 0}\n  <div\n    class=\"fixed bottom-6 right-6 z-[9999] flex flex-col gap-2 pointer-events-none\"\n    role=\"log\"\n    aria-live=\"polite\"\n    aria-label=\"Notifications\"\n  >\n    {#each items as toast (toast.id)}\n      {@const style = typeStyles[toast.type]}\n      <div\n        class=\"pointer-events-auto max-w-sm w-80 bg-exo-dark-gray/95 backdrop-blur-sm border border-exo-medium-gray/60 border-l-[3px] {style.border} rounded shadow-lg shadow-black/40\"\n        in:fly={{ x: 80, duration: 250 }}\n        out:fade={{ duration: 150 }}\n        animate:flip={{ duration: 200 }}\n        role=\"alert\"\n      >\n        <div class=\"flex items-start gap-3 px-4 py-3\">\n          <!-- Icon -->\n          <svg\n            class=\"w-5 h-5 flex-shrink-0 mt-0.5 {style.iconColor}\"\n            fill=\"none\"\n            viewBox=\"0 0 24 24\"\n            stroke-width=\"1.5\"\n            stroke=\"currentColor\"\n          >\n            <path\n              stroke-linecap=\"round\"\n              stroke-linejoin=\"round\"\n              d={style.icon}\n            />\n          </svg>\n\n          <!-- Message -->\n          <p class=\"flex-1 text-sm text-white/90 font-mono leading-snug\">\n            {toast.message}\n          </p>\n\n          <!-- Dismiss button -->\n          <button\n            onclick={() => dismissToast(toast.id)}\n            class=\"flex-shrink-0 p-0.5 text-white/40 hover:text-white/80 transition-colors cursor-pointer\"\n            aria-label=\"Dismiss notification\"\n          >\n            <svg\n              class=\"w-4 h-4\"\n              fill=\"none\"\n              viewBox=\"0 0 24 24\"\n              stroke-width=\"2\"\n              stroke=\"currentColor\"\n            >\n              <path\n                stroke-linecap=\"round\"\n                stroke-linejoin=\"round\"\n                d=\"M6 18L18 6M6 6l12 12\"\n              />\n            </svg>\n          </button>\n        </div>\n\n        <!-- Auto-dismiss progress bar -->\n        {#if toast.duration > 0}\n          <div class=\"h-0.5 bg-white/5 rounded-b overflow-hidden\">\n            <div\n              class=\"h-full {style.border.replace('border-l-', 'bg-')}/60\"\n              style=\"animation: 
shrink {toast.duration}ms linear forwards\"\n            ></div>\n          </div>\n        {/if}\n      </div>\n    {/each}\n  </div>\n{/if}\n\n<style>\n  @keyframes shrink {\n    from {\n      width: 100%;\n    }\n    to {\n      width: 0%;\n    }\n  }\n</style>\n"
  },
  {
    "path": "dashboard/src/lib/components/TokenHeatmap.svelte",
    "content": "<script lang=\"ts\">\n  import type { TokenData } from \"$lib/stores/app.svelte\";\n\n  interface Props {\n    tokens: TokenData[];\n    class?: string;\n    isGenerating?: boolean;\n    onRegenerateFrom?: (tokenIndex: number) => void;\n  }\n\n  let {\n    tokens,\n    class: className = \"\",\n    isGenerating = false,\n    onRegenerateFrom,\n  }: Props = $props();\n\n  // Tooltip state - track both token data and index\n  let hoveredTokenIndex = $state<number | null>(null);\n  let hoveredPosition = $state<{ x: number; y: number } | null>(null);\n  let isTooltipHovered = $state(false);\n  let hideTimeoutId: ReturnType<typeof setTimeout> | null = null;\n\n  // Derive the hovered token from the index (stable across re-renders)\n  const hoveredToken = $derived(\n    hoveredTokenIndex !== null && hoveredPosition && tokens[hoveredTokenIndex]\n      ? {\n          token: tokens[hoveredTokenIndex],\n          index: hoveredTokenIndex,\n          ...hoveredPosition,\n        }\n      : null,\n  );\n\n  /**\n   * Get confidence styling based on probability.\n   * Following Apple design principles: high confidence tokens blend in,\n   * only uncertainty draws attention.\n   */\n  function getConfidenceClass(probability: number): string {\n    if (probability > 0.8) return \"text-inherit\"; // Expected tokens - blend in\n    if (probability > 0.5) return \"bg-gray-500/10 text-inherit\"; // Slight hint\n    if (probability > 0.2) return \"bg-amber-500/15 text-amber-200/90\"; // Subtle warmth\n    return \"bg-red-500/20 text-red-200/90\"; // Draws attention\n  }\n\n  /**\n   * Get border/underline styling for uncertain tokens\n   */\n  function getBorderClass(probability: number): string {\n    if (probability > 0.8) return \"border-transparent\"; // No border for expected\n    if (probability > 0.5) return \"border-gray-500/20\";\n    if (probability > 0.2) return \"border-amber-500/30\";\n    return \"border-red-500/40\";\n  }\n\n  function clearHideTimeout() {\n    if (hideTimeoutId) {\n      clearTimeout(hideTimeoutId);\n      hideTimeoutId = null;\n    }\n  }\n\n  function handleMouseEnter(\n    event: MouseEvent,\n    token: TokenData,\n    index: number,\n  ) {\n    clearHideTimeout();\n    const rects = (event.target as HTMLElement).getClientRects();\n    let rect = rects[0];\n    for (let j = 0; j < rects.length; j++) {\n      if (event.clientY >= rects[j].top && event.clientY <= rects[j].bottom) {\n        rect = rects[j];\n        break;\n      }\n    }\n    hoveredTokenIndex = index;\n    hoveredPosition = {\n      x: rect.left + rect.width / 2,\n      y: rect.top - 10,\n    };\n  }\n\n  function handleMouseLeave() {\n    clearHideTimeout();\n    // Use longer delay during generation to account for re-renders\n    const delay = isGenerating ? 
300 : 200;\n    hideTimeoutId = setTimeout(() => {\n      if (!isTooltipHovered) {\n        hoveredTokenIndex = null;\n        hoveredPosition = null;\n      }\n    }, delay);\n  }\n\n  function handleTooltipEnter() {\n    clearHideTimeout();\n    isTooltipHovered = true;\n  }\n\n  function handleTooltipLeave() {\n    isTooltipHovered = false;\n    hoveredTokenIndex = null;\n    hoveredPosition = null;\n  }\n\n  function handleRegenerate() {\n    if (hoveredToken && onRegenerateFrom) {\n      const indexToRegenerate = hoveredToken.index;\n      // Clear hover state immediately\n      hoveredTokenIndex = null;\n      hoveredPosition = null;\n      isTooltipHovered = false;\n      // Call regenerate\n      onRegenerateFrom(indexToRegenerate);\n    }\n  }\n\n  function formatProbability(prob: number): string {\n    return (prob * 100).toFixed(1) + \"%\";\n  }\n\n  function formatLogprob(logprob: number): string {\n    return logprob.toFixed(3);\n  }\n\n  function getProbabilityColor(probability: number): string {\n    if (probability > 0.8) return \"text-gray-300\";\n    if (probability > 0.5) return \"text-gray-400\";\n    if (probability > 0.2) return \"text-amber-400\";\n    return \"text-red-400\";\n  }\n</script>\n\n<div class=\"token-heatmap leading-relaxed {className}\">\n  {#each tokens as tokenData, i (i)}\n    <span\n      role=\"button\"\n      tabindex=\"0\"\n      class=\"token-span inline rounded px-0.5 py-0.5 cursor-pointer transition-all duration-150 border {getConfidenceClass(\n        tokenData.probability,\n      )} {getBorderClass(tokenData.probability)} hover:opacity-80\"\n      onmouseenter={(e) => handleMouseEnter(e, tokenData, i)}\n      onmouseleave={handleMouseLeave}>{tokenData.token}</span\n    >\n  {/each}\n</div>\n\n<!-- Tooltip -->\n{#if hoveredToken}\n  <div\n    class=\"fixed z-50 pb-2\"\n    style=\"left: {hoveredToken.x}px; top: {hoveredToken.y}px; transform: translate(-50%, -100%);\"\n    onmouseenter={handleTooltipEnter}\n    onmouseleave={handleTooltipLeave}\n  >\n    <div\n      class=\"bg-gray-900/95 backdrop-blur-sm border border-gray-700/50 rounded-xl shadow-xl p-3 text-sm min-w-48\"\n    >\n      <!-- Token info -->\n      <div class=\"mb-2\">\n        <span class=\"text-gray-500 text-xs\">Token:</span>\n        <span class=\"text-white font-mono ml-1\"\n          >\"{hoveredToken.token.token}\"</span\n        >\n        <span class=\"{getProbabilityColor(hoveredToken.token.probability)} ml-2\"\n          >{formatProbability(hoveredToken.token.probability)}</span\n        >\n      </div>\n\n      <div class=\"text-gray-400 text-xs mb-1\">\n        logprob: <span class=\"text-gray-300 font-mono\"\n          >{formatLogprob(hoveredToken.token.logprob)}</span\n        >\n      </div>\n\n      <!-- Top alternatives -->\n      {#if hoveredToken.token.topLogprobs.length > 0}\n        <div class=\"border-t border-gray-700/50 mt-2 pt-2\">\n          <div class=\"text-gray-500 text-xs mb-1\">Alternatives:</div>\n          {#each hoveredToken.token.topLogprobs.slice(0, 5) as alt, idx (idx)}\n            {@const altProb = Math.exp(alt.logprob)}\n            <div class=\"flex justify-between items-center text-xs py-0.5\">\n              <span class=\"text-gray-300 font-mono truncate max-w-24\"\n                >\"{alt.token}\"</span\n              >\n              <span class=\"text-gray-400 ml-2\"\n                >{formatProbability(altProb)}</span\n              >\n            </div>\n          {/each}\n        </div>\n      {/if}\n\n      <!-- Regenerate 
button -->\n      {#if onRegenerateFrom}\n        <button\n          onclick={handleRegenerate}\n          class=\"w-full mt-2 pt-2 border-t border-gray-700/50 flex items-center justify-center gap-1.5 text-xs text-gray-400 hover:text-white transition-colors cursor-pointer\"\n        >\n          <svg\n            class=\"w-3 h-3\"\n            fill=\"none\"\n            viewBox=\"0 0 24 24\"\n            stroke=\"currentColor\"\n          >\n            <path\n              stroke-linecap=\"round\"\n              stroke-linejoin=\"round\"\n              stroke-width=\"2\"\n              d=\"M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15\"\n            />\n          </svg>\n          Regenerate from here\n        </button>\n      {/if}\n    </div>\n    <!-- Arrow -->\n    <div class=\"absolute left-1/2 -translate-x-1/2 top-full\">\n      <div class=\"border-8 border-transparent border-t-gray-900\"></div>\n    </div>\n  </div>\n{/if}\n\n<style>\n  .token-heatmap {\n    word-wrap: break-word;\n    white-space: pre-wrap;\n  }\n\n  .token-span {\n    margin: 0;\n    border-width: 1px;\n  }\n</style>\n"
  },
  {
    "path": "dashboard/src/lib/components/TopologyGraph.svelte",
    "content": "<script lang=\"ts\">\n  import { onMount, onDestroy } from \"svelte\";\n  import * as d3 from \"d3\";\n  import {\n    topologyData,\n    isTopologyMinimized,\n    debugMode,\n    nodeThunderboltBridge,\n    nodeRdmaCtl,\n    nodeIdentities,\n    type NodeInfo,\n  } from \"$lib/stores/app.svelte\";\n\n  interface Props {\n    class?: string;\n    highlightedNodes?: Set<string>;\n    filteredNodes?: Set<string>;\n    onNodeClick?: (nodeId: string) => void;\n  }\n\n  let {\n    class: className = \"\",\n    highlightedNodes = new Set(),\n    filteredNodes = new Set(),\n    onNodeClick,\n  }: Props = $props();\n\n  let svgContainer: SVGSVGElement | undefined = $state();\n  let resizeObserver: ResizeObserver | undefined;\n  let hoveredNodeId = $state<string | null>(null);\n\n  const isMinimized = $derived(isTopologyMinimized());\n  const data = $derived(topologyData());\n  const debugEnabled = $derived(debugMode());\n  const tbBridgeData = $derived(nodeThunderboltBridge());\n  const rdmaCtlData = $derived(nodeRdmaCtl());\n  const identitiesData = $derived(nodeIdentities());\n\n  function getNodeLabel(nodeId: string): string {\n    const node = data?.nodes?.[nodeId];\n    return node?.friendly_name || nodeId.slice(0, 8);\n  }\n\n  function getInterfaceLabel(\n    nodeId: string,\n    ip?: string,\n  ): { label: string; missing: boolean } {\n    if (!ip) return { label: \"?\", missing: true };\n\n    // Strip port if present (e.g., \"192.168.1.1:8080\" -> \"192.168.1.1\")\n    const cleanIp =\n      ip.includes(\":\") && !ip.includes(\"[\") ? ip.split(\":\")[0] : ip;\n\n    // Helper to check a node's interfaces\n    function checkNode(node: NodeInfo | undefined): string | null {\n      if (!node) return null;\n\n      const matchFromInterfaces = node.network_interfaces?.find((iface) =>\n        (iface.addresses || []).some((addr) => addr === cleanIp || addr === ip),\n      );\n      if (matchFromInterfaces?.name) {\n        return matchFromInterfaces.name;\n      }\n\n      if (node.ip_to_interface) {\n        const mapped =\n          node.ip_to_interface[cleanIp] ||\n          (ip ? node.ip_to_interface[ip] : undefined);\n        if (mapped && mapped.trim().length > 0) {\n          return mapped;\n        }\n      }\n      return null;\n    }\n\n    // Try specified node first\n    const result = checkNode(data?.nodes?.[nodeId]);\n    if (result) return { label: result, missing: false };\n\n    // Fallback: search all nodes for this IP\n    for (const [, otherNode] of Object.entries(data?.nodes || {})) {\n      const otherResult = checkNode(otherNode);\n      if (otherResult) return { label: otherResult, missing: false };\n    }\n\n    return { label: \"?\", missing: true };\n  }\n\n  function wrapLine(text: string, maxLen: number): string[] {\n    if (text.length <= maxLen) return [text];\n    const words = text.split(\" \");\n    const lines: string[] = [];\n    let current = \"\";\n    for (const word of words) {\n      if (word.length > maxLen) {\n        if (current) {\n          lines.push(current);\n          current = \"\";\n        }\n        for (let i = 0; i < word.length; i += maxLen) {\n          lines.push(word.slice(i, i + maxLen));\n        }\n      } else if ((current + \" \" + word).trim().length > maxLen) {\n        lines.push(current);\n        current = word;\n      } else {\n        current = current ? 
`${current} ${word}` : word;\n      }\n    }\n    if (current) lines.push(current);\n    return lines;\n  }\n\n  // Apple logo path for MacBook Pro screen\n  const APPLE_LOGO_PATH =\n    \"M788.1 340.9c-5.8 4.5-108.2 62.2-108.2 190.5 0 148.4 130.3 200.9 134.2 202.2-.6 3.2-20.7 71.9-68.7 141.9-42.8 61.6-87.5 123.1-155.5 123.1s-85.5-39.5-164-39.5c-76.5 0-103.7 40.8-165.9 40.8s-105.6-57-155.5-127C46.7 790.7 0 663 0 541.8c0-194.4 126.4-297.5 250.8-297.5 66.1 0 121.2 43.4 162.7 43.4 39.5 0 101.1-46 176.3-46 28.5 0 130.9 2.6 198.3 99.2zm-234-181.5c31.1-36.9 53.1-88.1 53.1-139.3 0-7.1-.6-14.3-1.9-20.1-50.6 1.9-110.8 33.7-147.1 75.8-28.5 32.4-55.1 83.6-55.1 135.5 0 7.8 1.3 15.6 1.9 18.1 3.2.6 8.4 1.3 13.6 1.3 45.4 0 102.5-30.4 135.5-71.3z\";\n  const LOGO_NATIVE_WIDTH = 814;\n  const LOGO_NATIVE_HEIGHT = 1000;\n\n  function formatBytes(bytes: number, decimals = 1): string {\n    if (!bytes || bytes === 0) return \"0B\";\n    const k = 1024;\n    const sizes = [\"B\", \"KB\", \"MB\", \"GB\", \"TB\"];\n    const i = Math.floor(Math.log(bytes) / Math.log(k));\n    return parseFloat((bytes / Math.pow(k, i)).toFixed(decimals)) + sizes[i];\n  }\n\n  function getTemperatureColor(temp: number): string {\n    // Default for N/A temp - light gray\n    if (isNaN(temp) || temp === null) return \"rgba(179, 179, 179, 0.8)\";\n\n    const coolTemp = 45; // Temp for pure blue\n    const midTemp = 57.5; // Temp for pure yellow\n    const hotTemp = 75; // Temp for pure red\n\n    const coolColor = { r: 93, g: 173, b: 226 }; // #5DADE2 (Blue)\n    const midColor = { r: 255, g: 215, b: 0 }; // #FFD700 (Yellow)\n    const hotColor = { r: 244, g: 67, b: 54 }; // #F44336 (Red)\n\n    let r: number, g: number, b: number;\n\n    if (temp <= coolTemp) {\n      ({ r, g, b } = coolColor);\n    } else if (temp <= midTemp) {\n      const ratio = (temp - coolTemp) / (midTemp - coolTemp);\n      r = Math.round(coolColor.r * (1 - ratio) + midColor.r * ratio);\n      g = Math.round(coolColor.g * (1 - ratio) + midColor.g * ratio);\n      b = Math.round(coolColor.b * (1 - ratio) + midColor.b * ratio);\n    } else if (temp < hotTemp) {\n      const ratio = (temp - midTemp) / (hotTemp - midTemp);\n      r = Math.round(midColor.r * (1 - ratio) + hotColor.r * ratio);\n      g = Math.round(midColor.g * (1 - ratio) + hotColor.g * ratio);\n      b = Math.round(midColor.b * (1 - ratio) + hotColor.b * ratio);\n    } else {\n      ({ r, g, b } = hotColor);\n    }\n\n    return `rgb(${r}, ${g}, ${b})`;\n  }\n\n  function renderGraph() {\n    if (!svgContainer || !data) return;\n\n    d3.select(svgContainer).selectAll(\"*\").remove();\n\n    const nodes = data.nodes || {};\n    const edges = data.edges || [];\n    const nodeIds = Object.keys(nodes);\n\n    const rect = svgContainer.getBoundingClientRect();\n    const width = rect.width;\n    const height = rect.height;\n    const centerX = width / 2;\n    const centerY = height / 2;\n\n    const svg = d3.select(svgContainer);\n\n    // Add defs for clip paths and filters\n    const defs = svg.append(\"defs\");\n\n    // Glow filter\n    const glowFilter = defs\n      .append(\"filter\")\n      .attr(\"id\", \"glow\")\n      .attr(\"x\", \"-50%\")\n      .attr(\"y\", \"-50%\")\n      .attr(\"width\", \"200%\")\n      .attr(\"height\", \"200%\");\n    glowFilter\n      .append(\"feGaussianBlur\")\n      .attr(\"stdDeviation\", \"2\")\n      .attr(\"result\", \"coloredBlur\");\n    const glowMerge = glowFilter.append(\"feMerge\");\n    glowMerge.append(\"feMergeNode\").attr(\"in\", 
\"coloredBlur\");\n    glowMerge.append(\"feMergeNode\").attr(\"in\", \"SourceGraphic\");\n\n    // Arrowhead marker for directional edges\n    const marker = defs\n      .append(\"marker\")\n      .attr(\"id\", \"arrowhead\")\n      .attr(\"viewBox\", \"0 0 10 10\")\n      .attr(\"refX\", \"10\")\n      .attr(\"refY\", \"5\")\n      .attr(\"markerWidth\", \"11\")\n      .attr(\"markerHeight\", \"11\")\n      .attr(\"orient\", \"auto-start-reverse\");\n    marker\n      .append(\"path\")\n      .attr(\"d\", \"M 0 0 L 10 5 L 0 10\")\n      .attr(\"fill\", \"none\")\n      .attr(\"stroke\", \"var(--exo-light-gray, #B3B3B3)\")\n      .attr(\"stroke-width\", \"1.6\")\n      .attr(\"stroke-linecap\", \"round\")\n      .attr(\"stroke-linejoin\", \"round\")\n      .style(\"animation\", \"none\");\n\n    if (nodeIds.length === 0) {\n      svg\n        .append(\"text\")\n        .attr(\"x\", centerX)\n        .attr(\"y\", centerY)\n        .attr(\"text-anchor\", \"middle\")\n        .attr(\"dominant-baseline\", \"middle\")\n        .attr(\"fill\", \"rgba(255,215,0,0.4)\")\n        .attr(\"font-size\", isMinimized ? 10 : 12)\n        .attr(\"font-family\", \"SF Mono, monospace\")\n        .attr(\"letter-spacing\", \"0.1em\")\n        .text(\"AWAITING NODES\");\n      return;\n    }\n\n    const numNodes = nodeIds.length;\n    const minDimension = Math.min(width, height);\n\n    // Dynamic scaling - larger nodes for big displays\n    const sizeScale =\n      numNodes === 1 ? 1 : Math.max(0.6, 1 - (numNodes - 1) * 0.1);\n    const baseNodeRadius = isMinimized\n      ? Math.max(36, Math.min(60, minDimension * 0.22))\n      : Math.min(120, minDimension * 0.2);\n    const nodeRadius = baseNodeRadius * sizeScale;\n\n    // Orbit radius - balanced spacing for nodes\n    const circumference = numNodes * nodeRadius * 4;\n    const radiusFromCircumference = circumference / (2 * Math.PI);\n    const minOrbitRadius = Math.max(\n      radiusFromCircumference,\n      minDimension * 0.18,\n    );\n    const maxOrbitRadius = minDimension * 0.3;\n    const orbitRadius = isMinimized\n      ? 
Math.min(maxOrbitRadius, Math.max(minOrbitRadius, minDimension * 0.26))\n      : Math.min(\n          maxOrbitRadius,\n          Math.max(minOrbitRadius, minDimension * (0.22 + numNodes * 0.02)),\n        );\n\n    // Determine display mode based on space and node count\n    const showFullLabels = !isMinimized && numNodes <= 4;\n    const showCompactLabels = !isMinimized && numNodes > 4;\n\n    // Add padding for labels (top/bottom)\n    const topPadding = 70; // Space for \"NETWORK TOPOLOGY\" label and node names\n    const bottomPadding = 70; // Space for stats and bottom label\n    const safeCenterY = topPadding + (height - topPadding - bottomPadding) / 2;\n\n    // Calculate node positions\n    const nodesWithPositions = nodeIds.map((id, index) => {\n      if (numNodes === 1) {\n        // Single node: center it\n        return {\n          id,\n          data: nodes[id],\n          x: centerX,\n          y: safeCenterY,\n        };\n      }\n      // Distribute nodes around the orbit\n      // Start from top (-90 degrees) and go clockwise\n      const angle = (index / numNodes) * 2 * Math.PI - Math.PI / 2;\n      return {\n        id,\n        data: nodes[id],\n        x: centerX + orbitRadius * Math.cos(angle),\n        y: safeCenterY + orbitRadius * Math.sin(angle),\n      };\n    });\n\n    const positionById: Record<string, { x: number; y: number }> = {};\n    nodesWithPositions.forEach((n) => {\n      positionById[n.id] = { x: n.x, y: n.y };\n    });\n\n    // Draw edges\n    const linksGroup = svg.append(\"g\").attr(\"class\", \"links-group\");\n    const arrowsGroup = svg.append(\"g\").attr(\"class\", \"arrows-group\");\n    const debugLabelsGroup = svg.append(\"g\").attr(\"class\", \"debug-edge-labels\");\n\n    type ConnectionInfo = {\n      from: string;\n      to: string;\n      ip: string;\n      ifaceLabel: string;\n      missingIface: boolean;\n    };\n    type PairEntry = {\n      a: string;\n      b: string;\n      aToB: boolean;\n      bToA: boolean;\n      connections: ConnectionInfo[];\n    };\n    type DebugEdgeLabelEntry = {\n      connections: ConnectionInfo[];\n      isLeft: boolean;\n      isTop: boolean;\n      mx: number;\n      my: number;\n    };\n    const pairMap = new Map<string, PairEntry>();\n    const debugEdgeLabels: DebugEdgeLabelEntry[] = [];\n    edges.forEach((edge) => {\n      if (!edge.source || !edge.target || edge.source === edge.target) return;\n      if (!positionById[edge.source] || !positionById[edge.target]) return;\n\n      const a = edge.source < edge.target ? edge.source : edge.target;\n      const b = edge.source < edge.target ? 
edge.target : edge.source;\n      const key = `${a}|${b}`;\n      const entry = pairMap.get(key) || {\n        a,\n        b,\n        aToB: false,\n        bToA: false,\n        connections: [],\n      };\n\n      if (edge.source === a) entry.aToB = true;\n      else entry.bToA = true;\n\n      let ip: string;\n      let ifaceLabel: string;\n      let missingIface: boolean;\n\n      if (edge.sourceRdmaIface || edge.sinkRdmaIface) {\n        ip = \"RDMA\";\n        ifaceLabel = `${edge.sourceRdmaIface || \"?\"} \\u2192 ${edge.sinkRdmaIface || \"?\"}`;\n        missingIface = false;\n      } else {\n        ip = edge.sendBackIp || \"?\";\n        const ifaceInfo = getInterfaceLabel(edge.source, ip);\n        ifaceLabel = ifaceInfo.label;\n        missingIface = ifaceInfo.missing;\n      }\n\n      entry.connections.push({\n        from: edge.source,\n        to: edge.target,\n        ip,\n        ifaceLabel,\n        missingIface,\n      });\n      pairMap.set(key, entry);\n    });\n\n    pairMap.forEach((entry) => {\n      const posA = positionById[entry.a];\n      const posB = positionById[entry.b];\n      if (!posA || !posB) return;\n\n      // Base dashed line\n      linksGroup\n        .append(\"line\")\n        .attr(\"x1\", posA.x)\n        .attr(\"y1\", posA.y)\n        .attr(\"x2\", posB.x)\n        .attr(\"y2\", posB.y)\n        .attr(\"class\", \"graph-link\");\n\n      // Calculate midpoint and direction for arrows\n      const dx = posB.x - posA.x;\n      const dy = posB.y - posA.y;\n      const len = Math.hypot(dx, dy) || 1;\n      const ux = dx / len;\n      const uy = dy / len;\n      const mx = (posA.x + posB.x) / 2;\n      const my = (posA.y + posB.y) / 2;\n      const tipOffset = 16; // Distance from center for arrow tips\n      const carrier = 2; // Short segment length for arrow orientation\n\n      // Arrow A -> B (if connection exists in that direction)\n      if (entry.aToB) {\n        const tipX = mx - ux * tipOffset;\n        const tipY = my - uy * tipOffset;\n        arrowsGroup\n          .append(\"line\")\n          .attr(\"x1\", tipX - ux * carrier)\n          .attr(\"y1\", tipY - uy * carrier)\n          .attr(\"x2\", tipX)\n          .attr(\"y2\", tipY)\n          .attr(\"stroke\", \"none\")\n          .attr(\"fill\", \"none\")\n          .attr(\"marker-end\", \"url(#arrowhead)\");\n      }\n\n      // Arrow B -> A (if connection exists in that direction)\n      if (entry.bToA) {\n        const tipX = mx + ux * tipOffset;\n        const tipY = my + uy * tipOffset;\n        arrowsGroup\n          .append(\"line\")\n          .attr(\"x1\", tipX + ux * carrier)\n          .attr(\"y1\", tipY + uy * carrier)\n          .attr(\"x2\", tipX)\n          .attr(\"y2\", tipY)\n          .attr(\"stroke\", \"none\")\n          .attr(\"fill\", \"none\")\n          .attr(\"marker-end\", \"url(#arrowhead)\");\n      }\n\n      // Collect debug labels for later positioning at edges\n      if (debugEnabled && entry.connections.length > 0) {\n        // Determine which side of viewport based on edge midpoint\n        const isLeft = mx < centerX;\n        const isTop = my < safeCenterY;\n\n        // Store for batch rendering after all edges processed\n        debugEdgeLabels.push({\n          connections: entry.connections,\n          isLeft,\n          isTop,\n          mx,\n          my,\n        });\n      }\n    });\n\n    // Render debug labels at viewport edges/corners\n    if (debugEdgeLabels && debugEdgeLabels.length > 0) {\n      const fontSize = isMinimized ? 
10 : 12;\n      const lineHeight = fontSize + 4;\n      const padding = 10;\n\n      // Helper to get arrow based on direction vector\n      function getArrow(fromId: string, toId: string): string {\n        const fromPos = positionById[fromId];\n        const toPos = positionById[toId];\n        if (!fromPos || !toPos) return \"→\";\n\n        const dirX = toPos.x - fromPos.x;\n        const dirY = toPos.y - fromPos.y;\n        const absX = Math.abs(dirX);\n        const absY = Math.abs(dirY);\n\n        if (absX > absY * 2) {\n          return dirX > 0 ? \"→\" : \"←\";\n        } else if (absY > absX * 2) {\n          return dirY > 0 ? \"↓\" : \"↑\";\n        } else {\n          if (dirX > 0 && dirY > 0) return \"↘\";\n          if (dirX > 0 && dirY < 0) return \"↗\";\n          if (dirX < 0 && dirY > 0) return \"↙\";\n          return \"↖\";\n        }\n      }\n\n      // Group by quadrant: topLeft, topRight, bottomLeft, bottomRight\n      const quadrants: Record<string, DebugEdgeLabelEntry[]> = {\n        topLeft: [],\n        topRight: [],\n        bottomLeft: [],\n        bottomRight: [],\n      };\n\n      debugEdgeLabels.forEach((edge) => {\n        const key =\n          (edge.isTop ? \"top\" : \"bottom\") + (edge.isLeft ? \"Left\" : \"Right\");\n        quadrants[key].push(edge);\n      });\n\n      // Render each quadrant\n      Object.entries(quadrants).forEach(([quadrant, quadrantEdges]) => {\n        if (quadrantEdges.length === 0) return;\n\n        const isLeft = quadrant.includes(\"Left\");\n        const isTop = quadrant.includes(\"top\");\n\n        let baseX = isLeft ? padding : width - padding;\n        let baseY = isTop ? padding : height - padding;\n        const textAnchor = isLeft ? \"start\" : \"end\";\n\n        let currentY = baseY;\n\n        quadrantEdges.forEach((edge) => {\n          edge.connections.forEach((conn) => {\n            const arrow = getArrow(conn.from, conn.to);\n            const label = `${arrow} ${conn.ip} ${conn.ifaceLabel}`;\n            debugLabelsGroup\n              .append(\"text\")\n              .attr(\"x\", baseX)\n              .attr(\"y\", currentY)\n              .attr(\"text-anchor\", textAnchor)\n              .attr(\"dominant-baseline\", isTop ? \"hanging\" : \"auto\")\n              .attr(\"font-size\", fontSize)\n              .attr(\"font-family\", \"SF Mono, monospace\")\n              .attr(\n                \"fill\",\n                conn.missingIface\n                  ? \"rgba(248,113,113,0.9)\"\n                  : \"rgba(255,255,255,0.85)\",\n              )\n              .text(label);\n            currentY += isTop ? 
lineHeight : -lineHeight;\n          });\n        });\n      });\n    }\n\n    // Draw nodes\n    const nodesGroup = svg.append(\"g\").attr(\"class\", \"nodes-group\");\n\n    nodesWithPositions.forEach((nodeInfo) => {\n      const node = nodeInfo.data;\n      const macmon = node.macmon_info;\n      const modelId = node.system_info?.model_id || \"Unknown\";\n      const friendlyName = node.friendly_name || modelId;\n\n      let ramUsagePercent = 0;\n      let gpuTemp = NaN;\n      let ramTotal = 0;\n      let ramUsed = 0;\n      let gpuUsagePercent = 0;\n      let sysPower: number | null = null;\n\n      if (macmon) {\n        if (macmon.memory && macmon.memory.ram_total > 0) {\n          ramUsagePercent =\n            (macmon.memory.ram_usage / macmon.memory.ram_total) * 100;\n          ramTotal = macmon.memory.ram_total;\n          ramUsed = macmon.memory.ram_usage;\n        }\n        if (macmon.temp && typeof macmon.temp.gpu_temp_avg === \"number\") {\n          gpuTemp = Math.max(30, macmon.temp.gpu_temp_avg);\n        }\n        if (macmon.gpu_usage) {\n          gpuUsagePercent = macmon.gpu_usage[1] * 100;\n        }\n        if (macmon.sys_power) {\n          sysPower = macmon.sys_power;\n        }\n      }\n\n      let iconBaseWidth = nodeRadius * 1.2;\n      let iconBaseHeight = nodeRadius * 1.0;\n      const clipPathId = `clip-${nodeInfo.id.replace(/[^a-zA-Z0-9]/g, \"-\")}`;\n\n      const modelLower = modelId.toLowerCase();\n\n      // Check node states for styling\n      const isHighlighted = highlightedNodes.has(nodeInfo.id);\n      const isInFilter =\n        filteredNodes.size > 0 && filteredNodes.has(nodeInfo.id);\n      const isFilteredOut =\n        filteredNodes.size > 0 && !filteredNodes.has(nodeInfo.id);\n      const isHovered = hoveredNodeId === nodeInfo.id && !isInFilter;\n\n      // Holographic wireframe colors - bright yellow for filter, subtle yellow for hover, grey for filtered out\n      const wireColor = isInFilter\n        ? \"rgba(255,215,0,1)\" // Bright yellow for filter selection\n        : isHovered\n          ? \"rgba(255,215,0,0.7)\" // Subtle yellow for hover\n          : isHighlighted\n            ? \"rgba(255,215,0,0.9)\" // Yellow for instance highlight\n            : isFilteredOut\n              ? \"rgba(140,140,140,0.6)\" // Grey for filtered out\n              : \"rgba(179,179,179,0.8)\"; // Default\n      const wireColorBright = \"rgba(255,255,255,0.9)\";\n      const fillColor = isInFilter\n        ? \"rgba(255,215,0,0.25)\"\n        : isHovered\n          ? \"rgba(255,215,0,0.12)\"\n          : isHighlighted\n            ? \"rgba(255,215,0,0.15)\"\n            : \"rgba(255,215,0,0.08)\";\n      const strokeWidth = isInFilter\n        ? 3\n        : isHovered\n          ? 2\n          : isHighlighted\n            ? 2.5\n            : 1.5;\n      const screenFill = \"rgba(0,20,40,0.9)\";\n      const glowColor = \"rgba(255,215,0,0.3)\";\n\n      const nodeG = nodesGroup\n        .append(\"g\")\n        .attr(\"class\", \"graph-node\")\n        .style(\"cursor\", onNodeClick ? \"pointer\" : \"default\")\n        .style(\"opacity\", isFilteredOut ? 
0.5 : 1);\n\n      // Add click and hover handlers - hover just updates state, styling is applied during render\n      nodeG\n        .on(\"click\", (event: MouseEvent) => {\n          if (onNodeClick) {\n            event.stopPropagation();\n            onNodeClick(nodeInfo.id);\n          }\n        })\n        .on(\"mouseenter\", () => {\n          if (onNodeClick) {\n            hoveredNodeId = nodeInfo.id;\n          }\n        })\n        .on(\"mouseleave\", () => {\n          if (hoveredNodeId === nodeInfo.id) {\n            hoveredNodeId = null;\n          }\n        });\n\n      // Add tooltip\n      nodeG\n        .append(\"title\")\n        .text(\n          `${friendlyName}\\nID: ${nodeInfo.id.slice(-8)}\\nMemory: ${formatBytes(ramUsed)}/${formatBytes(ramTotal)}`,\n        );\n\n      if (modelLower === \"mac studio\") {\n        // Mac Studio - classic cube with memory fill\n        iconBaseWidth = nodeRadius * 1.25;\n        iconBaseHeight = nodeRadius * 0.85;\n        const x = nodeInfo.x - iconBaseWidth / 2;\n        const y = nodeInfo.y - iconBaseHeight / 2;\n        const cornerRadius = 4;\n        const topSurfaceHeight = iconBaseHeight * 0.15;\n\n        // Create clip path for memory fill area (front body)\n        const studioClipId = `studio-clip-${nodeInfo.id.replace(/[^a-zA-Z0-9]/g, \"-\")}`;\n        defs\n          .append(\"clipPath\")\n          .attr(\"id\", studioClipId)\n          .append(\"rect\")\n          .attr(\"x\", x)\n          .attr(\"y\", y + topSurfaceHeight)\n          .attr(\"width\", iconBaseWidth)\n          .attr(\"height\", iconBaseHeight - topSurfaceHeight)\n          .attr(\"rx\", cornerRadius - 1);\n\n        // Main body (uniform color)\n        nodeG\n          .append(\"rect\")\n          .attr(\"class\", \"node-outline\")\n          .attr(\"x\", x)\n          .attr(\"y\", y)\n          .attr(\"width\", iconBaseWidth)\n          .attr(\"height\", iconBaseHeight)\n          .attr(\"rx\", cornerRadius)\n          .attr(\"fill\", \"#1a1a1a\")\n          .attr(\"stroke\", wireColor)\n          .attr(\"stroke-width\", strokeWidth);\n\n        // Memory fill (fills from bottom up)\n        if (ramUsagePercent > 0) {\n          const memFillTotalHeight = iconBaseHeight - topSurfaceHeight;\n          const memFillActualHeight =\n            (ramUsagePercent / 100) * memFillTotalHeight;\n          nodeG\n            .append(\"rect\")\n            .attr(\"x\", x)\n            .attr(\n              \"y\",\n              y + topSurfaceHeight + (memFillTotalHeight - memFillActualHeight),\n            )\n            .attr(\"width\", iconBaseWidth)\n            .attr(\"height\", memFillActualHeight)\n            .attr(\"fill\", \"rgba(255,215,0,0.75)\")\n            .attr(\"clip-path\", `url(#${studioClipId})`);\n        }\n\n        // Front panel details - vertical slots\n        const detailColor = \"rgba(0,0,0,0.35)\";\n        const slotHeight = iconBaseHeight * 0.14;\n        const vSlotWidth = iconBaseWidth * 0.05;\n        const vSlotY =\n          y + topSurfaceHeight + (iconBaseHeight - topSurfaceHeight) * 0.6;\n        const vSlot1X = x + iconBaseWidth * 0.18;\n        const vSlot2X = x + iconBaseWidth * 0.28;\n\n        [vSlot1X, vSlot2X].forEach((vx) => {\n          nodeG\n            .append(\"rect\")\n            .attr(\"x\", vx - vSlotWidth / 2)\n            .attr(\"y\", vSlotY)\n            .attr(\"width\", vSlotWidth)\n            .attr(\"height\", slotHeight)\n            .attr(\"fill\", detailColor)\n            .attr(\"rx\", 
1.5);\n        });\n\n        // Horizontal slot (SD card)\n        const hSlotWidth = iconBaseWidth * 0.2;\n        const hSlotX = x + iconBaseWidth * 0.5 - hSlotWidth / 2;\n        nodeG\n          .append(\"rect\")\n          .attr(\"x\", hSlotX)\n          .attr(\"y\", vSlotY)\n          .attr(\"width\", hSlotWidth)\n          .attr(\"height\", slotHeight * 0.6)\n          .attr(\"fill\", detailColor)\n          .attr(\"rx\", 1);\n      } else if (modelLower === \"mac mini\") {\n        // Mac Mini - classic flat box with memory fill\n        iconBaseWidth = nodeRadius * 1.3;\n        iconBaseHeight = nodeRadius * 0.7;\n        const x = nodeInfo.x - iconBaseWidth / 2;\n        const y = nodeInfo.y - iconBaseHeight / 2;\n        const cornerRadius = 3;\n        const topSurfaceHeight = iconBaseHeight * 0.2;\n\n        // Create clip path for memory fill area\n        const miniClipId = `mini-clip-${nodeInfo.id.replace(/[^a-zA-Z0-9]/g, \"-\")}`;\n        defs\n          .append(\"clipPath\")\n          .attr(\"id\", miniClipId)\n          .append(\"rect\")\n          .attr(\"x\", x)\n          .attr(\"y\", y + topSurfaceHeight)\n          .attr(\"width\", iconBaseWidth)\n          .attr(\"height\", iconBaseHeight - topSurfaceHeight)\n          .attr(\"rx\", cornerRadius - 1);\n\n        // Main body (uniform color)\n        nodeG\n          .append(\"rect\")\n          .attr(\"class\", \"node-outline\")\n          .attr(\"x\", x)\n          .attr(\"y\", y)\n          .attr(\"width\", iconBaseWidth)\n          .attr(\"height\", iconBaseHeight)\n          .attr(\"rx\", cornerRadius)\n          .attr(\"fill\", \"#1a1a1a\")\n          .attr(\"stroke\", wireColor)\n          .attr(\"stroke-width\", strokeWidth);\n\n        // Memory fill (fills from bottom up)\n        if (ramUsagePercent > 0) {\n          const memFillTotalHeight = iconBaseHeight - topSurfaceHeight;\n          const memFillActualHeight =\n            (ramUsagePercent / 100) * memFillTotalHeight;\n          nodeG\n            .append(\"rect\")\n            .attr(\"x\", x)\n            .attr(\n              \"y\",\n              y + topSurfaceHeight + (memFillTotalHeight - memFillActualHeight),\n            )\n            .attr(\"width\", iconBaseWidth)\n            .attr(\"height\", memFillActualHeight)\n            .attr(\"fill\", \"rgba(255,215,0,0.75)\")\n            .attr(\"clip-path\", `url(#${miniClipId})`);\n        }\n\n        // Front panel details - vertical slots (no horizontal slot for Mini)\n        const detailColor = \"rgba(0,0,0,0.35)\";\n        const slotHeight = iconBaseHeight * 0.2;\n        const vSlotWidth = iconBaseWidth * 0.045;\n        const vSlotY =\n          y + topSurfaceHeight + (iconBaseHeight - topSurfaceHeight) * 0.45;\n        const vSlot1X = x + iconBaseWidth * 0.2;\n        const vSlot2X = x + iconBaseWidth * 0.3;\n\n        [vSlot1X, vSlot2X].forEach((vx) => {\n          nodeG\n            .append(\"rect\")\n            .attr(\"x\", vx - vSlotWidth / 2)\n            .attr(\"y\", vSlotY)\n            .attr(\"width\", vSlotWidth)\n            .attr(\"height\", slotHeight)\n            .attr(\"fill\", detailColor)\n            .attr(\"rx\", 1.2);\n        });\n      } else if (\n        modelLower === \"macbook pro\" ||\n        modelLower.includes(\"macbook\")\n      ) {\n        // MacBook Pro - classic style with memory fill on screen\n        iconBaseWidth = nodeRadius * 1.6;\n        iconBaseHeight = nodeRadius * 1.15;\n        const x = nodeInfo.x - iconBaseWidth / 2;\n        const 
y = nodeInfo.y - iconBaseHeight / 2;\n\n        const screenHeight = iconBaseHeight * 0.7;\n        const baseHeight = iconBaseHeight * 0.3;\n        const screenWidth = iconBaseWidth * 0.85;\n        const screenX = nodeInfo.x - screenWidth / 2;\n        const screenBezel = 3;\n\n        // Create clip path for screen content\n        const screenClipId = `screen-clip-${nodeInfo.id.replace(/[^a-zA-Z0-9]/g, \"-\")}`;\n        defs\n          .append(\"clipPath\")\n          .attr(\"id\", screenClipId)\n          .append(\"rect\")\n          .attr(\"x\", screenX + screenBezel)\n          .attr(\"y\", y + screenBezel)\n          .attr(\"width\", screenWidth - screenBezel * 2)\n          .attr(\"height\", screenHeight - screenBezel * 2)\n          .attr(\"rx\", 2);\n\n        // Screen outer frame\n        nodeG\n          .append(\"rect\")\n          .attr(\"class\", \"node-outline\")\n          .attr(\"x\", screenX)\n          .attr(\"y\", y)\n          .attr(\"width\", screenWidth)\n          .attr(\"height\", screenHeight)\n          .attr(\"rx\", 3)\n          .attr(\"fill\", \"#1a1a1a\")\n          .attr(\"stroke\", wireColor)\n          .attr(\"stroke-width\", strokeWidth);\n\n        // Screen inner (dark background)\n        nodeG\n          .append(\"rect\")\n          .attr(\"x\", screenX + screenBezel)\n          .attr(\"y\", y + screenBezel)\n          .attr(\"width\", screenWidth - screenBezel * 2)\n          .attr(\"height\", screenHeight - screenBezel * 2)\n          .attr(\"rx\", 2)\n          .attr(\"fill\", \"#0a0a12\");\n\n        // Memory fill on screen (fills from bottom up - classic style)\n        if (ramUsagePercent > 0) {\n          const memFillTotalHeight = screenHeight - screenBezel * 2;\n          const memFillActualHeight =\n            (ramUsagePercent / 100) * memFillTotalHeight;\n          nodeG\n            .append(\"rect\")\n            .attr(\"x\", screenX + screenBezel)\n            .attr(\n              \"y\",\n              y + screenBezel + (memFillTotalHeight - memFillActualHeight),\n            )\n            .attr(\"width\", screenWidth - screenBezel * 2)\n            .attr(\"height\", memFillActualHeight)\n            .attr(\"fill\", \"rgba(255,215,0,0.85)\")\n            .attr(\"clip-path\", `url(#${screenClipId})`);\n        }\n\n        // Apple logo on screen (centered, on top of memory fill)\n        const targetLogoHeight = screenHeight * 0.22;\n        const logoScale = targetLogoHeight / LOGO_NATIVE_HEIGHT;\n        const logoX = nodeInfo.x - (LOGO_NATIVE_WIDTH * logoScale) / 2;\n        const logoY =\n          y + screenHeight / 2 - (LOGO_NATIVE_HEIGHT * logoScale) / 2;\n        nodeG\n          .append(\"path\")\n          .attr(\"d\", APPLE_LOGO_PATH)\n          .attr(\n            \"transform\",\n            `translate(${logoX}, ${logoY}) scale(${logoScale})`,\n          )\n          .attr(\"fill\", \"#FFFFFF\")\n          .attr(\"opacity\", 0.9);\n\n        // Base (keyboard) - trapezoidal\n        const baseY = y + screenHeight;\n        const baseTopWidth = screenWidth;\n        const baseBottomWidth = iconBaseWidth;\n        const baseTopX = nodeInfo.x - baseTopWidth / 2;\n        const baseBottomX = nodeInfo.x - baseBottomWidth / 2;\n\n        nodeG\n          .append(\"path\")\n          .attr(\n            \"d\",\n            `M ${baseTopX} ${baseY} L ${baseTopX + baseTopWidth} ${baseY} L ${baseBottomX + baseBottomWidth} ${baseY + baseHeight} L ${baseBottomX} ${baseY + baseHeight} Z`,\n          )\n          .attr(\"fill\", 
\"#2c2c2c\")\n          .attr(\"stroke\", wireColor)\n          .attr(\"stroke-width\", 1);\n\n        // Keyboard area\n        const keyboardX = baseTopX + 6;\n        const keyboardY = baseY + 3;\n        const keyboardWidth = baseTopWidth - 12;\n        const keyboardHeight = baseHeight * 0.55;\n        nodeG\n          .append(\"rect\")\n          .attr(\"x\", keyboardX)\n          .attr(\"y\", keyboardY)\n          .attr(\"width\", keyboardWidth)\n          .attr(\"height\", keyboardHeight)\n          .attr(\"fill\", \"rgba(0,0,0,0.2)\")\n          .attr(\"rx\", 2);\n\n        // Trackpad\n        const trackpadWidth = baseTopWidth * 0.4;\n        const trackpadX = nodeInfo.x - trackpadWidth / 2;\n        const trackpadY = baseY + keyboardHeight + 5;\n        const trackpadHeight = baseHeight * 0.3;\n        nodeG\n          .append(\"rect\")\n          .attr(\"x\", trackpadX)\n          .attr(\"y\", trackpadY)\n          .attr(\"width\", trackpadWidth)\n          .attr(\"height\", trackpadHeight)\n          .attr(\"fill\", \"rgba(255,255,255,0.08)\")\n          .attr(\"rx\", 2);\n      } else {\n        // Default/Unknown - holographic hexagon\n        const hexRadius = nodeRadius * 0.6;\n        const hexPoints = Array.from({ length: 6 }, (_, i) => {\n          const angle = ((i * 60 - 30) * Math.PI) / 180;\n          return `${nodeInfo.x + hexRadius * Math.cos(angle)},${nodeInfo.y + hexRadius * Math.sin(angle)}`;\n        }).join(\" \");\n\n        // Main shape\n        nodeG\n          .append(\"polygon\")\n          .attr(\"class\", \"node-outline\")\n          .attr(\"points\", hexPoints)\n          .attr(\"fill\", fillColor)\n          .attr(\"stroke\", wireColor)\n          .attr(\"stroke-width\", strokeWidth);\n      }\n\n      // --- Vertical GPU Bar (right side of icon) ---\n      // Show in both full mode and minimized mode (scaled appropriately)\n      if (showFullLabels || isMinimized) {\n        const gpuBarWidth = isMinimized\n          ? Math.max(16, nodeRadius * 0.32)\n          : Math.max(28, nodeRadius * 0.3);\n        const gpuBarHeight = iconBaseHeight * 0.95;\n        const barXOffset = iconBaseWidth / 2 + (isMinimized ? 5 : 10);\n        const gpuBarX = nodeInfo.x + barXOffset;\n        const gpuBarY = nodeInfo.y - gpuBarHeight / 2;\n\n        // GPU Bar Background (grey, no border)\n        nodeG\n          .append(\"rect\")\n          .attr(\"x\", gpuBarX)\n          .attr(\"y\", gpuBarY)\n          .attr(\"width\", gpuBarWidth)\n          .attr(\"height\", gpuBarHeight)\n          .attr(\"fill\", \"rgba(80, 80, 90, 0.7)\")\n          .attr(\"rx\", 2);\n\n        // GPU Bar Fill (from bottom up, colored by temperature)\n        if (gpuUsagePercent > 0) {\n          const fillHeight = (gpuUsagePercent / 100) * gpuBarHeight;\n          const gpuFillColor = getTemperatureColor(gpuTemp);\n          nodeG\n            .append(\"rect\")\n            .attr(\"x\", gpuBarX)\n            .attr(\"y\", gpuBarY + (gpuBarHeight - fillHeight))\n            .attr(\"width\", gpuBarWidth)\n            .attr(\"height\", fillHeight)\n            .attr(\"fill\", gpuFillColor)\n            .attr(\"opacity\", 0.9)\n            .attr(\"rx\", 2);\n        }\n\n        // GPU Stats Text (centered on bar, multiline, bigger and bold)\n        const gpuTextX = gpuBarX + gpuBarWidth / 2;\n        const gpuTextY = gpuBarY + gpuBarHeight / 2;\n        const gpuTextFontSize = isMinimized\n          ? 
Math.max(10, gpuBarWidth * 0.6)\n          : Math.min(16, Math.max(12, gpuBarWidth * 0.55));\n        const lineSpacing = gpuTextFontSize * 1.25;\n\n        const gpuUsageText = `${gpuUsagePercent.toFixed(0)}%`;\n        const tempText = !isNaN(gpuTemp) ? `${gpuTemp.toFixed(0)}°C` : \"-\";\n        const powerText = sysPower !== null ? `${sysPower.toFixed(0)}W` : \"-\";\n\n        // GPU Usage %\n        nodeG\n          .append(\"text\")\n          .attr(\"x\", gpuTextX)\n          .attr(\"y\", gpuTextY - lineSpacing)\n          .attr(\"text-anchor\", \"middle\")\n          .attr(\"dominant-baseline\", \"middle\")\n          .attr(\"fill\", \"#FFFFFF\")\n          .attr(\"font-size\", gpuTextFontSize)\n          .attr(\"font-weight\", \"700\")\n          .attr(\"font-family\", \"SF Mono, Monaco, monospace\")\n          .text(gpuUsageText);\n\n        // Temperature\n        nodeG\n          .append(\"text\")\n          .attr(\"x\", gpuTextX)\n          .attr(\"y\", gpuTextY)\n          .attr(\"text-anchor\", \"middle\")\n          .attr(\"dominant-baseline\", \"middle\")\n          .attr(\"fill\", \"#FFFFFF\")\n          .attr(\"font-size\", gpuTextFontSize)\n          .attr(\"font-weight\", \"700\")\n          .attr(\"font-family\", \"SF Mono, Monaco, monospace\")\n          .text(tempText);\n\n        // Power (Watts)\n        nodeG\n          .append(\"text\")\n          .attr(\"x\", gpuTextX)\n          .attr(\"y\", gpuTextY + lineSpacing)\n          .attr(\"text-anchor\", \"middle\")\n          .attr(\"dominant-baseline\", \"middle\")\n          .attr(\"fill\", \"#FFFFFF\")\n          .attr(\"font-size\", gpuTextFontSize)\n          .attr(\"font-weight\", \"700\")\n          .attr(\"font-family\", \"SF Mono, Monaco, monospace\")\n          .text(powerText);\n      }\n\n      // Labels - adapt based on mode\n      if (showFullLabels) {\n        // FULL MODE: Name above, memory info below (1-4 nodes)\n        const nameY = nodeInfo.y - iconBaseHeight / 2 - 15;\n        const fontSize = Math.max(10, nodeRadius * 0.16);\n\n        // Truncate name based on node count\n        const maxNameLen =\n          numNodes === 1 ? 22 : numNodes === 2 ? 18 : numNodes === 3 ? 16 : 14;\n        const displayName =\n          friendlyName.length > maxNameLen\n            ? 
friendlyName.slice(0, maxNameLen - 2) + \"..\"\n            : friendlyName;\n\n        // Name label above\n        nodeG\n          .append(\"text\")\n          .attr(\"x\", nodeInfo.x)\n          .attr(\"y\", nameY)\n          .attr(\"text-anchor\", \"middle\")\n          .attr(\"dominant-baseline\", \"middle\")\n          .attr(\"fill\", \"#FFD700\")\n          .attr(\"font-size\", fontSize)\n          .attr(\"font-weight\", 500)\n          .attr(\"font-family\", \"SF Mono, Monaco, monospace\")\n          .text(displayName);\n\n        // Memory info below - used in grey, total in yellow\n        const infoY = nodeInfo.y + iconBaseHeight / 2 + 16;\n        const memText = nodeG\n          .append(\"text\")\n          .attr(\"x\", nodeInfo.x)\n          .attr(\"y\", infoY)\n          .attr(\"text-anchor\", \"middle\")\n          .attr(\"font-size\", fontSize * 0.85)\n          .attr(\"font-family\", \"SF Mono, Monaco, monospace\");\n        memText\n          .append(\"tspan\")\n          .attr(\"fill\", \"rgba(255,215,0,0.9)\")\n          .text(`${formatBytes(ramUsed)}`);\n        memText\n          .append(\"tspan\")\n          .attr(\"fill\", \"rgba(179,179,179,0.9)\")\n          .text(`/${formatBytes(ramTotal)}`);\n        memText\n          .append(\"tspan\")\n          .attr(\"fill\", \"rgba(179,179,179,0.7)\")\n          .text(` (${ramUsagePercent.toFixed(0)}%)`);\n      } else if (showCompactLabels) {\n        // COMPACT MODE: Just name and basic info (4+ nodes)\n        const fontSize = Math.max(7, nodeRadius * 0.11);\n\n        // Very compact name below icon\n        const nameY = nodeInfo.y + iconBaseHeight / 2 + 9;\n        const shortName =\n          friendlyName.length > 10\n            ? friendlyName.slice(0, 8) + \"..\"\n            : friendlyName;\n        nodeG\n          .append(\"text\")\n          .attr(\"x\", nodeInfo.x)\n          .attr(\"y\", nameY)\n          .attr(\"text-anchor\", \"middle\")\n          .attr(\"fill\", \"#FFD700\")\n          .attr(\"font-size\", fontSize)\n          .attr(\"font-family\", \"SF Mono, Monaco, monospace\")\n          .text(shortName);\n\n        // Single line of key stats\n        const statsY = nameY + 9;\n        nodeG\n          .append(\"text\")\n          .attr(\"x\", nodeInfo.x)\n          .attr(\"y\", statsY)\n          .attr(\"text-anchor\", \"middle\")\n          .attr(\"fill\", \"rgba(255,215,0,0.7)\")\n          .attr(\"font-size\", fontSize * 0.85)\n          .attr(\"font-family\", \"SF Mono, Monaco, monospace\")\n          .text(\n            `${ramUsagePercent.toFixed(0)}%${!isNaN(gpuTemp) ? \" \" + gpuTemp.toFixed(0) + \"°C\" : \"\"}`,\n          );\n      } else {\n        // MINIMIZED MODE: Show name above and memory info below (like main topology)\n        const fontSize = 8;\n\n        // Friendly name (shortened) above icon\n        const nameY = nodeInfo.y - iconBaseHeight / 2 - 8;\n        const shortName =\n          friendlyName.length > 12\n            ? 
friendlyName.slice(0, 10) + \"..\"\n            : friendlyName;\n        nodeG\n          .append(\"text\")\n          .attr(\"x\", nodeInfo.x)\n          .attr(\"y\", nameY)\n          .attr(\"text-anchor\", \"middle\")\n          .attr(\"fill\", \"#FFD700\")\n          .attr(\"font-size\", fontSize)\n          .attr(\"font-weight\", \"500\")\n          .attr(\"font-family\", \"SF Mono, Monaco, monospace\")\n          .text(shortName);\n\n        // Memory info below icon - used in grey, total in yellow (same as main topology)\n        const infoY = nodeInfo.y + iconBaseHeight / 2 + 10;\n        const memTextMini = nodeG\n          .append(\"text\")\n          .attr(\"x\", nodeInfo.x)\n          .attr(\"y\", infoY)\n          .attr(\"text-anchor\", \"middle\")\n          .attr(\"font-size\", fontSize * 0.85)\n          .attr(\"font-family\", \"SF Mono, Monaco, monospace\");\n        memTextMini\n          .append(\"tspan\")\n          .attr(\"fill\", \"rgba(255,215,0,0.9)\")\n          .text(`${formatBytes(ramUsed)}`);\n        memTextMini\n          .append(\"tspan\")\n          .attr(\"fill\", \"rgba(179,179,179,0.9)\")\n          .text(`/${formatBytes(ramTotal)}`);\n        memTextMini\n          .append(\"tspan\")\n          .attr(\"fill\", \"rgba(179,179,179,0.7)\")\n          .text(` (${ramUsagePercent.toFixed(0)}%)`);\n      }\n\n      // Debug mode: Show TB bridge and RDMA status\n      if (debugEnabled) {\n        let debugLabelY =\n          nodeInfo.y +\n          iconBaseHeight / 2 +\n          (showFullLabels ? 32 : showCompactLabels ? 26 : 22);\n        const debugFontSize = showFullLabels ? 9 : 7;\n        const debugLineHeight = showFullLabels ? 11 : 9;\n\n        const tbStatus = tbBridgeData[nodeInfo.id];\n        if (tbStatus) {\n          const tbColor = tbStatus.enabled\n            ? \"rgba(234,179,8,0.9)\"\n            : \"rgba(100,100,100,0.7)\";\n          const tbText = tbStatus.enabled ? \"TB:ON\" : \"TB:OFF\";\n          nodeG\n            .append(\"text\")\n            .attr(\"x\", nodeInfo.x)\n            .attr(\"y\", debugLabelY)\n            .attr(\"text-anchor\", \"middle\")\n            .attr(\"fill\", tbColor)\n            .attr(\"font-size\", debugFontSize)\n            .attr(\"font-family\", \"SF Mono, Monaco, monospace\")\n            .text(tbText);\n          debugLabelY += debugLineHeight;\n        }\n\n        const rdmaStatus = rdmaCtlData[nodeInfo.id];\n        if (rdmaStatus !== undefined) {\n          const rdmaColor = rdmaStatus.enabled\n            ? \"rgba(74,222,128,0.9)\"\n            : \"rgba(100,100,100,0.7)\";\n          const rdmaText = rdmaStatus.enabled ? 
\"RDMA:ON\" : \"RDMA:OFF\";\n          nodeG\n            .append(\"text\")\n            .attr(\"x\", nodeInfo.x)\n            .attr(\"y\", debugLabelY)\n            .attr(\"text-anchor\", \"middle\")\n            .attr(\"fill\", rdmaColor)\n            .attr(\"font-size\", debugFontSize)\n            .attr(\"font-family\", \"SF Mono, Monaco, monospace\")\n            .text(rdmaText);\n          debugLabelY += debugLineHeight;\n        }\n\n        const identity = identitiesData[nodeInfo.id];\n        if (identity?.osVersion) {\n          nodeG\n            .append(\"text\")\n            .attr(\"x\", nodeInfo.x)\n            .attr(\"y\", debugLabelY)\n            .attr(\"text-anchor\", \"middle\")\n            .attr(\"fill\", \"rgba(179,179,179,0.7)\")\n            .attr(\"font-size\", debugFontSize)\n            .attr(\"font-family\", \"SF Mono, Monaco, monospace\")\n            .text(\n              `macOS ${identity.osVersion}${identity.osBuildVersion ? ` (${identity.osBuildVersion})` : \"\"}`,\n            );\n        }\n      }\n    });\n  }\n\n  $effect(() => {\n    // Track all reactive dependencies that affect rendering\n    const _data = data;\n    const _hoveredNodeId = hoveredNodeId;\n    const _filteredNodes = filteredNodes;\n    const _highlightedNodes = highlightedNodes;\n    if (_data) {\n      renderGraph();\n    }\n  });\n\n  onMount(() => {\n    if (svgContainer) {\n      resizeObserver = new ResizeObserver(() => {\n        renderGraph();\n      });\n      resizeObserver.observe(svgContainer);\n    }\n  });\n\n  onDestroy(() => {\n    resizeObserver?.disconnect();\n  });\n</script>\n\n<svg bind:this={svgContainer} class=\"w-full h-full {className}\"></svg>\n\n<style>\n  :global(.graph-node) {\n    /* Only transition opacity for filtered-out nodes, no transition on hover stroke changes */\n    transition: opacity 0.2s ease;\n  }\n  :global(.graph-link) {\n    stroke: var(--exo-light-gray, #b3b3b3);\n    stroke-width: 1px;\n    stroke-dasharray: 4, 4;\n    opacity: 0.8;\n    animation: flowAnimation 0.75s linear infinite;\n  }\n  @keyframes flowAnimation {\n    from {\n      stroke-dashoffset: 0;\n    }\n    to {\n      stroke-dashoffset: -10;\n    }\n  }\n</style>\n"
  },
  {
    "path": "dashboard/src/lib/components/index.ts",
    "content": "export { default as TopologyGraph } from \"./TopologyGraph.svelte\";\nexport { default as ChatForm } from \"./ChatForm.svelte\";\nexport { default as ChatMessages } from \"./ChatMessages.svelte\";\nexport { default as ChatAttachments } from \"./ChatAttachments.svelte\";\nexport { default as ChatSidebar } from \"./ChatSidebar.svelte\";\nexport { default as ModelCard } from \"./ModelCard.svelte\";\nexport { default as MarkdownContent } from \"./MarkdownContent.svelte\";\nexport { default as ImageParamsPanel } from \"./ImageParamsPanel.svelte\";\nexport { default as FamilyLogos } from \"./FamilyLogos.svelte\";\nexport { default as FamilySidebar } from \"./FamilySidebar.svelte\";\nexport { default as HuggingFaceResultItem } from \"./HuggingFaceResultItem.svelte\";\nexport { default as ModelFilterPopover } from \"./ModelFilterPopover.svelte\";\nexport { default as ModelPickerGroup } from \"./ModelPickerGroup.svelte\";\nexport { default as ModelPickerModal } from \"./ModelPickerModal.svelte\";\nexport { default as ChatModelSelector } from \"./ChatModelSelector.svelte\";\n"
  },
  {
    "path": "dashboard/src/lib/stores/app.svelte.ts",
    "content": "/**\n * AppStore - Central state management for the EXO dashboard\n *\n * Manages:\n * - Chat state (whether a conversation has started)\n * - Topology data from the EXO server\n * - UI state for the topology/chat transition\n */\n\nimport { browser } from \"$app/environment\";\n\n// UUID generation fallback for browsers without crypto.randomUUID\nfunction generateUUID(): string {\n  if (\n    typeof crypto !== \"undefined\" &&\n    typeof crypto.randomUUID === \"function\"\n  ) {\n    return crypto.randomUUID();\n  }\n  // Fallback implementation\n  return \"xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx\".replace(/[xy]/g, (c) => {\n    const r = (Math.random() * 16) | 0;\n    const v = c === \"x\" ? r : (r & 0x3) | 0x8;\n    return v.toString(16);\n  });\n}\n\nexport interface NodeInfo {\n  system_info?: {\n    model_id?: string;\n    chip?: string;\n    memory?: number;\n  };\n  network_interfaces?: Array<{\n    name?: string;\n    addresses?: string[];\n  }>;\n  ip_to_interface?: Record<string, string>;\n  macmon_info?: {\n    memory?: {\n      ram_usage: number;\n      ram_total: number;\n    };\n    temp?: {\n      gpu_temp_avg: number;\n    };\n    gpu_usage?: [number, number];\n    sys_power?: number;\n  };\n  last_macmon_update: number;\n  friendly_name?: string;\n  os_version?: string;\n}\n\nexport interface TopologyEdge {\n  source: string;\n  target: string;\n  sendBackIp?: string;\n  sendBackInterface?: string;\n  sourceRdmaIface?: string;\n  sinkRdmaIface?: string;\n}\n\nexport interface TopologyData {\n  nodes: Record<string, NodeInfo>;\n  edges: TopologyEdge[];\n}\n\nexport interface Instance {\n  shardAssignments?: {\n    modelId?: string;\n    runnerToShard?: Record<string, unknown>;\n    nodeToRunner?: Record<string, string>;\n  };\n}\n\n// Granular node state types from the new state structure\ninterface RawNodeIdentity {\n  modelId?: string;\n  chipId?: string;\n  friendlyName?: string;\n  osVersion?: string;\n  osBuildVersion?: string;\n}\n\ninterface RawMemoryUsage {\n  ramTotal?: { inBytes: number };\n  ramAvailable?: { inBytes: number };\n  swapTotal?: { inBytes: number };\n  swapAvailable?: { inBytes: number };\n}\n\ninterface RawSystemPerformanceProfile {\n  gpuUsage?: number;\n  temp?: number;\n  sysPower?: number;\n  pcpuUsage?: number;\n  ecpuUsage?: number;\n}\n\ninterface RawNetworkInterfaceInfo {\n  name?: string;\n  ipAddress?: string;\n  addresses?: Array<{ address?: string } | string>;\n  ipv4?: string;\n  ipv6?: string;\n  ipAddresses?: string[];\n  ips?: string[];\n}\n\ninterface RawNodeNetworkInfo {\n  interfaces?: RawNetworkInterfaceInfo[];\n}\n\ninterface RawSocketConnection {\n  sinkMultiaddr?: {\n    address?: string;\n    ip_address?: string;\n    address_type?: string;\n    port?: number;\n  };\n}\n\ninterface RawRDMAConnection {\n  sourceRdmaIface?: string;\n  sinkRdmaIface?: string;\n}\n\ntype RawConnectionEdge = RawSocketConnection | RawRDMAConnection;\n\n// New nested mapping format: { source: { sink: [edge1, edge2, ...] 
} }\ntype RawConnectionsMap = Record<string, Record<string, RawConnectionEdge[]>>;\n\ninterface RawTopology {\n  nodes: string[];\n  connections?: RawConnectionsMap;\n}\n\nexport interface DownloadProgress {\n  totalBytes: number;\n  downloadedBytes: number;\n  speed: number;\n  etaMs: number;\n  percentage: number;\n  completedFiles: number;\n  totalFiles: number;\n  files: Array<{\n    name: string;\n    totalBytes: number;\n    downloadedBytes: number;\n    speed: number;\n    etaMs: number;\n    percentage: number;\n  }>;\n}\n\nexport interface ModelDownloadStatus {\n  isDownloading: boolean;\n  progress: DownloadProgress | null;\n  nodeDetails: Array<{\n    nodeId: string;\n    nodeName: string;\n    progress: DownloadProgress;\n  }>;\n}\n\n// Placement preview from the API\nexport interface PlacementPreview {\n  model_id: string;\n  sharding: \"Pipeline\" | \"Tensor\";\n  instance_meta: \"MlxRing\" | \"MlxJaccl\";\n  instance: unknown | null;\n  memory_delta_by_node: Record<string, number> | null;\n  error: string | null;\n}\n\nexport interface PlacementPreviewResponse {\n  previews: PlacementPreview[];\n}\n\ninterface ImageApiResponse {\n  created: number;\n  data: Array<{ b64_json?: string; url?: string }>;\n}\n\n// Trace API response types\nexport interface TraceCategoryStats {\n  totalUs: number;\n  count: number;\n  minUs: number;\n  maxUs: number;\n  avgUs: number;\n}\n\nexport interface TraceRankStats {\n  byCategory: Record<string, TraceCategoryStats>;\n}\n\nexport interface TraceStatsResponse {\n  taskId: string;\n  totalWallTimeUs: number;\n  byCategory: Record<string, TraceCategoryStats>;\n  byRank: Record<number, TraceRankStats>;\n}\n\nexport interface TraceListItem {\n  taskId: string;\n  createdAt: string;\n  fileSize: number;\n}\n\nexport interface TraceListResponse {\n  traces: TraceListItem[];\n}\n\ninterface RawStateResponse {\n  topology?: RawTopology;\n  instances?: Record<\n    string,\n    {\n      MlxRingInstance?: Instance;\n      MlxJacclInstance?: Instance;\n    }\n  >;\n  runners?: Record<string, unknown>;\n  downloads?: Record<string, unknown[]>;\n  // New granular node state fields\n  nodeIdentities?: Record<string, RawNodeIdentity>;\n  nodeMemory?: Record<string, RawMemoryUsage>;\n  nodeSystem?: Record<string, RawSystemPerformanceProfile>;\n  nodeNetwork?: Record<string, RawNodeNetworkInfo>;\n  // Thunderbolt identifiers per node\n  nodeThunderbolt?: Record<\n    string,\n    {\n      interfaces: Array<{\n        rdmaInterface: string;\n        domainUuid: string;\n        linkSpeed: string;\n      }>;\n    }\n  >;\n  // RDMA ctl status per node\n  nodeRdmaCtl?: Record<string, { enabled: boolean }>;\n  // Thunderbolt bridge status per node\n  nodeThunderboltBridge?: Record<\n    string,\n    { enabled: boolean; exists: boolean; serviceName?: string | null }\n  >;\n  // Thunderbolt bridge cycles (nodes with bridge enabled forming loops)\n  thunderboltBridgeCycles?: string[][];\n  // Disk usage per node\n  nodeDisk?: Record<\n    string,\n    { total: { inBytes: number }; available: { inBytes: number } }\n  >;\n}\n\nexport interface MessageAttachment {\n  type: \"image\" | \"text\" | \"file\" | \"generated-image\";\n  name: string;\n  content?: string;\n  preview?: string;\n  mimeType?: string;\n}\n\nexport interface TopLogprob {\n  token: string;\n  logprob: number;\n  bytes: number[] | null;\n}\n\nexport interface TokenData {\n  token: string;\n  logprob: number;\n  probability: number;\n  topLogprobs: TopLogprob[];\n}\n\nexport interface PrefillProgress 
{\n  processed: number;\n  total: number;\n  /** Timestamp (performance.now()) when prefill started. */\n  startedAt: number;\n}\n\nexport interface Message {\n  id: string;\n  role: \"user\" | \"assistant\" | \"system\";\n  content: string;\n  timestamp: number;\n  thinking?: string;\n  attachments?: MessageAttachment[];\n  ttftMs?: number; // Time to first token in ms (for assistant messages)\n  tps?: number; // Tokens per second (for assistant messages)\n  requestType?: \"chat\" | \"image-generation\" | \"image-editing\";\n  sourceImageDataUrl?: string; // For image editing regeneration\n  tokens?: TokenData[];\n}\n\nexport interface Conversation {\n  id: string;\n  name: string;\n  messages: Message[];\n  createdAt: number;\n  updatedAt: number;\n  modelId: string | null;\n  sharding: string | null;\n  instanceType: string | null;\n  enableThinking: boolean | null;\n}\n\nconst STORAGE_KEY = \"exo-conversations\";\nconst IMAGE_PARAMS_STORAGE_KEY = \"exo-image-generation-params\";\n\n// Image generation params interface matching backend API\nexport interface ImageGenerationParams {\n  // Basic params\n  size:\n    | \"auto\"\n    | \"512x512\"\n    | \"768x768\"\n    | \"1024x1024\"\n    | \"1024x768\"\n    | \"768x1024\"\n    | \"1024x1536\"\n    | \"1536x1024\";\n  quality: \"low\" | \"medium\" | \"high\";\n  outputFormat: \"png\" | \"jpeg\";\n  numImages: number;\n  // Streaming params\n  stream: boolean;\n  partialImages: number;\n  // Advanced params\n  seed: number | null;\n  numInferenceSteps: number | null;\n  guidance: number | null;\n  negativePrompt: string | null;\n  numSyncSteps: number | null;\n  // Edit mode params\n  inputFidelity: \"low\" | \"high\";\n}\n\n// Image being edited\nexport interface EditingImage {\n  imageDataUrl: string;\n  sourceMessage: Message;\n}\n\nconst DEFAULT_IMAGE_PARAMS: ImageGenerationParams = {\n  size: \"auto\",\n  quality: \"medium\",\n  outputFormat: \"png\",\n  numImages: 1,\n  stream: true,\n  partialImages: 3,\n  seed: null,\n  numInferenceSteps: null,\n  guidance: null,\n  negativePrompt: null,\n  numSyncSteps: null,\n  inputFidelity: \"low\",\n};\n\ninterface GranularNodeState {\n  nodeIdentities?: Record<string, RawNodeIdentity>;\n  nodeMemory?: Record<string, RawMemoryUsage>;\n  nodeSystem?: Record<string, RawSystemPerformanceProfile>;\n  nodeNetwork?: Record<string, RawNodeNetworkInfo>;\n}\n\nfunction transformNetworkInterface(iface: RawNetworkInterfaceInfo): {\n  name?: string;\n  addresses: string[];\n} {\n  const addresses: string[] = [];\n  if (iface.ipAddress && typeof iface.ipAddress === \"string\") {\n    addresses.push(iface.ipAddress);\n  }\n  if (Array.isArray(iface.addresses)) {\n    for (const addr of iface.addresses) {\n      if (typeof addr === \"string\") addresses.push(addr);\n      else if (addr && typeof addr === \"object\" && addr.address)\n        addresses.push(addr.address);\n    }\n  }\n  if (Array.isArray(iface.ipAddresses)) {\n    addresses.push(\n      ...iface.ipAddresses.filter((a): a is string => typeof a === \"string\"),\n    );\n  }\n  if (Array.isArray(iface.ips)) {\n    addresses.push(\n      ...iface.ips.filter((a): a is string => typeof a === \"string\"),\n    );\n  }\n  if (iface.ipv4 && typeof iface.ipv4 === \"string\") addresses.push(iface.ipv4);\n  if (iface.ipv6 && typeof iface.ipv6 === \"string\") addresses.push(iface.ipv6);\n\n  return {\n    name: iface.name,\n    addresses: Array.from(new Set(addresses)),\n  };\n}\n\nfunction transformTopology(\n  raw: RawTopology,\n  granularState: 
GranularNodeState,\n): TopologyData {\n  const nodes: Record<string, NodeInfo> = {};\n  const edges: TopologyEdge[] = [];\n\n  for (const nodeId of raw.nodes || []) {\n    if (!nodeId) continue;\n\n    // Get data from granular state mappings\n    const identity = granularState.nodeIdentities?.[nodeId];\n    const memory = granularState.nodeMemory?.[nodeId];\n    const system = granularState.nodeSystem?.[nodeId];\n    const network = granularState.nodeNetwork?.[nodeId];\n\n    const ramTotal = memory?.ramTotal?.inBytes ?? 0;\n    const ramAvailable = memory?.ramAvailable?.inBytes ?? 0;\n    const ramUsage = Math.max(ramTotal - ramAvailable, 0);\n\n    const rawInterfaces = network?.interfaces || [];\n    const networkInterfaces = rawInterfaces.map(transformNetworkInterface);\n\n    const ipToInterface: Record<string, string> = {};\n    for (const iface of networkInterfaces) {\n      for (const addr of iface.addresses || []) {\n        ipToInterface[addr] = iface.name ?? \"\";\n      }\n    }\n\n    nodes[nodeId] = {\n      system_info: {\n        model_id: identity?.modelId ?? \"Unknown\",\n        chip: identity?.chipId,\n        memory: ramTotal,\n      },\n      network_interfaces: networkInterfaces,\n      ip_to_interface: ipToInterface,\n      macmon_info: {\n        memory: {\n          ram_usage: ramUsage,\n          ram_total: ramTotal,\n        },\n        temp:\n          system?.temp !== undefined\n            ? { gpu_temp_avg: system.temp }\n            : undefined,\n        gpu_usage:\n          system?.gpuUsage !== undefined ? [0, system.gpuUsage] : undefined,\n        sys_power: system?.sysPower,\n      },\n      last_macmon_update: Date.now() / 1000,\n      friendly_name: identity?.friendlyName,\n      os_version: identity?.osVersion,\n    };\n  }\n\n  // Handle connections - nested mapping format { source: { sink: [edges] } }\n  const connections = raw.connections;\n  if (connections && typeof connections === \"object\") {\n    for (const [source, sinks] of Object.entries(connections)) {\n      if (!sinks || typeof sinks !== \"object\") continue;\n      for (const [sink, edgeList] of Object.entries(sinks)) {\n        if (!Array.isArray(edgeList)) continue;\n        for (const edge of edgeList) {\n          let sendBackIp: string | undefined;\n          let sourceRdmaIface: string | undefined;\n          let sinkRdmaIface: string | undefined;\n          if (edge && typeof edge === \"object\" && \"sinkMultiaddr\" in edge) {\n            const multiaddr = edge.sinkMultiaddr;\n            if (multiaddr) {\n              sendBackIp =\n                multiaddr.ip_address ||\n                extractIpFromMultiaddr(multiaddr.address);\n            }\n          } else if (\n            edge &&\n            typeof edge === \"object\" &&\n            \"sourceRdmaIface\" in edge\n          ) {\n            sourceRdmaIface = edge.sourceRdmaIface;\n            sinkRdmaIface = edge.sinkRdmaIface;\n          }\n\n          if (nodes[source] && nodes[sink] && source !== sink) {\n            edges.push({\n              source,\n              target: sink,\n              sendBackIp,\n              sourceRdmaIface,\n              sinkRdmaIface,\n            });\n          }\n        }\n      }\n    }\n  }\n\n  return { nodes, edges };\n}\n\nfunction extractIpFromMultiaddr(ma?: string): string | undefined {\n  if (!ma) return undefined;\n  const parts = ma.split(\"/\");\n  const ip4Idx = parts.indexOf(\"ip4\");\n  const ip6Idx = parts.indexOf(\"ip6\");\n  const idx = ip4Idx >= 0 ? 
ip4Idx : ip6Idx;\n  if (idx >= 0 && parts.length > idx + 1) {\n    return parts[idx + 1];\n  }\n  return undefined;\n}\n\nclass AppStore {\n  // Conversation state\n  conversations = $state<Conversation[]>([]);\n  activeConversationId = $state<string | null>(null);\n\n  // Chat state\n  hasStartedChat = $state(false);\n  messages = $state<Message[]>([]);\n  currentResponse = $state(\"\");\n  isLoading = $state(false);\n\n  // Performance metrics\n  ttftMs = $state<number | null>(null); // Time to first token in ms\n  tps = $state<number | null>(null); // Tokens per second\n  totalTokens = $state<number>(0); // Total tokens in current response\n  prefillProgress = $state<PrefillProgress | null>(null);\n\n  // Abort controller for stopping generation\n  private currentAbortController: AbortController | null = null;\n\n  // Topology state\n  topologyData = $state<TopologyData | null>(null);\n  instances = $state<Record<string, unknown>>({});\n  runners = $state<Record<string, unknown>>({});\n  downloads = $state<Record<string, unknown[]>>({});\n  nodeDisk = $state<\n    Record<\n      string,\n      { total: { inBytes: number }; available: { inBytes: number } }\n    >\n  >({});\n  placementPreviews = $state<PlacementPreview[]>([]);\n  selectedPreviewModelId = $state<string | null>(null);\n  isLoadingPreviews = $state(false);\n  previewNodeFilter = $state<Set<string>>(new Set());\n  lastUpdate = $state<number | null>(null);\n  nodeIdentities = $state<Record<string, RawNodeIdentity>>({});\n  thunderboltBridgeCycles = $state<string[][]>([]);\n  nodeThunderbolt = $state<\n    Record<\n      string,\n      {\n        interfaces: Array<{\n          rdmaInterface: string;\n          domainUuid: string;\n          linkSpeed: string;\n        }>;\n      }\n    >\n  >({});\n  nodeRdmaCtl = $state<Record<string, { enabled: boolean }>>({});\n  nodeThunderboltBridge = $state<\n    Record<\n      string,\n      { enabled: boolean; exists: boolean; serviceName?: string | null }\n    >\n  >({});\n\n  // UI state\n  isTopologyMinimized = $state(false);\n  isSidebarOpen = $state(false); // Hidden by default, shown when in chat mode\n  debugMode = $state(false);\n  topologyOnlyMode = $state(false);\n  chatSidebarVisible = $state(true); // Shown by default\n  mobileChatSidebarOpen = $state(false); // Mobile drawer state\n  mobileRightSidebarOpen = $state(false); // Mobile right drawer state\n\n  // Image generation params\n  imageGenerationParams = $state<ImageGenerationParams>({\n    ...DEFAULT_IMAGE_PARAMS,\n  });\n\n  // Image editing state\n  editingImage = $state<EditingImage | null>(null);\n\n  /** True when the backend is reachable. */\n  isConnected = $state<boolean>(true);\n  /** Number of consecutive fetch failures. 
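Once this reaches CONNECTION_LOST_THRESHOLD, fetchState() flips isConnected to false; any successful fetch resets it. 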
*/\n  private consecutiveFailures = 0;\n  private static readonly CONNECTION_LOST_THRESHOLD = 3;\n\n  private fetchInterval: ReturnType<typeof setInterval> | null = null;\n  private previewsInterval: ReturnType<typeof setInterval> | null = null;\n  private lastConversationPersistTs = 0;\n  private previousNodeIds: Set<string> = new Set();\n\n  constructor() {\n    if (browser) {\n      this.startPolling();\n      this.loadConversationsFromStorage();\n      this.loadDebugModeFromStorage();\n      this.loadTopologyOnlyModeFromStorage();\n      this.loadChatSidebarVisibleFromStorage();\n      this.loadImageGenerationParamsFromStorage();\n    }\n  }\n\n  /**\n   * Load conversations from localStorage\n   */\n  private loadConversationsFromStorage() {\n    try {\n      const stored = localStorage.getItem(STORAGE_KEY);\n      if (stored) {\n        const parsed = JSON.parse(stored) as Array<Partial<Conversation>>;\n        this.conversations = parsed.map((conversation) => ({\n          id: conversation.id ?? generateUUID(),\n          name: conversation.name ?? \"Chat\",\n          messages: conversation.messages ?? [],\n          createdAt: conversation.createdAt ?? Date.now(),\n          updatedAt: conversation.updatedAt ?? Date.now(),\n          modelId: conversation.modelId ?? null,\n          sharding: conversation.sharding ?? null,\n          instanceType: conversation.instanceType ?? null,\n          enableThinking: conversation.enableThinking ?? null,\n        }));\n      }\n    } catch (error) {\n      console.error(\"Failed to load conversations:\", error);\n    }\n  }\n\n  /**\n   * Save conversations to localStorage\n   */\n  private saveConversationsToStorage() {\n    try {\n      // Strip tokens from messages before saving to avoid bloating localStorage\n      const stripped = this.conversations.map((conv) => ({\n        ...conv,\n        messages: conv.messages.map((msg) => {\n          if (msg.tokens) {\n            const { tokens: _, ...rest } = msg;\n            return rest;\n          }\n          return msg;\n        }),\n      }));\n      localStorage.setItem(STORAGE_KEY, JSON.stringify(stripped));\n    } catch (error) {\n      console.error(\"Failed to save conversations:\", error);\n    }\n  }\n\n  private loadDebugModeFromStorage() {\n    try {\n      const stored = localStorage.getItem(\"exo-debug-mode\");\n      if (stored !== null) {\n        this.debugMode = stored === \"true\";\n      }\n    } catch (error) {\n      console.error(\"Failed to load debug mode:\", error);\n    }\n  }\n\n  private saveDebugModeToStorage() {\n    try {\n      localStorage.setItem(\"exo-debug-mode\", this.debugMode ? \"true\" : \"false\");\n    } catch (error) {\n      console.error(\"Failed to save debug mode:\", error);\n    }\n  }\n\n  private loadTopologyOnlyModeFromStorage() {\n    try {\n      const stored = localStorage.getItem(\"exo-topology-only-mode\");\n      if (stored !== null) {\n        this.topologyOnlyMode = stored === \"true\";\n      }\n    } catch (error) {\n      console.error(\"Failed to load topology only mode:\", error);\n    }\n  }\n\n  private saveTopologyOnlyModeToStorage() {\n    try {\n      localStorage.setItem(\n        \"exo-topology-only-mode\",\n        this.topologyOnlyMode ? 
\"true\" : \"false\",\n      );\n    } catch (error) {\n      console.error(\"Failed to save topology only mode:\", error);\n    }\n  }\n\n  private loadChatSidebarVisibleFromStorage() {\n    try {\n      const stored = localStorage.getItem(\"exo-chat-sidebar-visible\");\n      if (stored !== null) {\n        this.chatSidebarVisible = stored === \"true\";\n      }\n    } catch (error) {\n      console.error(\"Failed to load chat sidebar visibility:\", error);\n    }\n  }\n\n  private saveChatSidebarVisibleToStorage() {\n    try {\n      localStorage.setItem(\n        \"exo-chat-sidebar-visible\",\n        this.chatSidebarVisible ? \"true\" : \"false\",\n      );\n    } catch (error) {\n      console.error(\"Failed to save chat sidebar visibility:\", error);\n    }\n  }\n\n  private loadImageGenerationParamsFromStorage() {\n    try {\n      const stored = localStorage.getItem(IMAGE_PARAMS_STORAGE_KEY);\n      if (stored) {\n        const parsed = JSON.parse(stored) as Partial<ImageGenerationParams>;\n        this.imageGenerationParams = {\n          ...DEFAULT_IMAGE_PARAMS,\n          ...parsed,\n        };\n      }\n    } catch (error) {\n      console.error(\"Failed to load image generation params:\", error);\n    }\n  }\n\n  private saveImageGenerationParamsToStorage() {\n    try {\n      localStorage.setItem(\n        IMAGE_PARAMS_STORAGE_KEY,\n        JSON.stringify(this.imageGenerationParams),\n      );\n    } catch (error) {\n      console.error(\"Failed to save image generation params:\", error);\n    }\n  }\n\n  getImageGenerationParams(): ImageGenerationParams {\n    return this.imageGenerationParams;\n  }\n\n  setImageGenerationParams(params: Partial<ImageGenerationParams>) {\n    this.imageGenerationParams = {\n      ...this.imageGenerationParams,\n      ...params,\n    };\n    this.saveImageGenerationParamsToStorage();\n  }\n\n  resetImageGenerationParams() {\n    this.imageGenerationParams = { ...DEFAULT_IMAGE_PARAMS };\n    this.saveImageGenerationParamsToStorage();\n  }\n\n  setEditingImage(imageDataUrl: string, sourceMessage: Message) {\n    this.editingImage = { imageDataUrl, sourceMessage };\n  }\n\n  clearEditingImage() {\n    this.editingImage = null;\n  }\n\n  /**\n   * Create a new conversation\n   */\n  createConversation(name?: string): string {\n    const id = generateUUID();\n    const now = Date.now();\n\n    // Try to derive model and strategy immediately from selected model or running instances\n    let derivedModelId = this.selectedChatModel || null;\n    let derivedInstanceType: string | null = null;\n    let derivedSharding: string | null = null;\n\n    // If no selected model, fall back to the first running instance\n    if (!derivedModelId) {\n      const firstInstance = Object.values(this.instances)[0];\n      if (firstInstance) {\n        const candidateModel = this.extractInstanceModelId(firstInstance);\n        derivedModelId = candidateModel ?? 
null;\n        const details = this.describeInstance(firstInstance);\n        derivedInstanceType = details.instanceType;\n        derivedSharding = details.sharding;\n      }\n    } else {\n      // If selected model is set, attempt to get its details from instances\n      for (const [, instanceWrapper] of Object.entries(this.instances)) {\n        const candidateModelId = this.extractInstanceModelId(instanceWrapper);\n        if (candidateModelId === derivedModelId) {\n          const details = this.describeInstance(instanceWrapper);\n          derivedInstanceType = details.instanceType;\n          derivedSharding = details.sharding;\n          break;\n        }\n      }\n    }\n\n    const conversation: Conversation = {\n      id,\n      name:\n        name ||\n        `Chat ${new Date(now).toLocaleString(\"en-US\", { month: \"short\", day: \"numeric\", hour: \"2-digit\", minute: \"2-digit\" })}`,\n      messages: [],\n      createdAt: now,\n      updatedAt: now,\n      modelId: derivedModelId,\n      sharding: derivedSharding,\n      instanceType: derivedInstanceType,\n      enableThinking: null,\n    };\n\n    this.conversations.unshift(conversation);\n    this.activeConversationId = id;\n    this.messages = [];\n    this.hasStartedChat = true;\n    this.isTopologyMinimized = true;\n    this.isSidebarOpen = true; // Auto-open sidebar when chatting\n\n    this.saveConversationsToStorage();\n    return id;\n  }\n\n  /**\n   * Load a conversation by ID\n   */\n  loadConversation(id: string): boolean {\n    const conversation = this.conversations.find((c) => c.id === id);\n    if (!conversation) return false;\n\n    this.activeConversationId = id;\n    this.messages = [...conversation.messages];\n    this.hasStartedChat = true;\n    this.isTopologyMinimized = true;\n    this.isSidebarOpen = true; // Auto-open sidebar when chatting\n    this.thinkingEnabled = conversation.enableThinking ?? 
true;\n    this.refreshConversationModelFromInstances();\n\n    // Sync global selection to the loaded conversation's model so reactive\n    // effects in +page.svelte can determine the correct chat launch state.\n    this.selectedChatModel = conversation.modelId || \"\";\n\n    return true;\n  }\n\n  /**\n   * Delete a conversation by ID\n   */\n  deleteConversation(id: string) {\n    this.conversations = this.conversations.filter((c) => c.id !== id);\n\n    if (this.activeConversationId === id) {\n      this.activeConversationId = null;\n      this.messages = [];\n      this.hasStartedChat = false;\n      this.isTopologyMinimized = false;\n    }\n\n    this.saveConversationsToStorage();\n  }\n\n  /**\n   * Delete all conversations\n   */\n  deleteAllConversations() {\n    this.conversations = [];\n    this.activeConversationId = null;\n    this.messages = [];\n    this.hasStartedChat = false;\n    this.isTopologyMinimized = false;\n    this.saveConversationsToStorage();\n  }\n\n  /**\n   * Rename a conversation\n   */\n  renameConversation(id: string, newName: string) {\n    const conversation = this.conversations.find((c) => c.id === id);\n    if (conversation) {\n      conversation.name = newName;\n      conversation.updatedAt = Date.now();\n      this.saveConversationsToStorage();\n    }\n  }\n\n  private getTaggedValue(obj: unknown): [string | null, unknown] {\n    if (!obj || typeof obj !== \"object\") return [null, null];\n    const keys = Object.keys(obj as Record<string, unknown>);\n    if (keys.length === 1) {\n      return [keys[0], (obj as Record<string, unknown>)[keys[0]]];\n    }\n    return [null, null];\n  }\n\n  private extractInstanceModelId(instanceWrapped: unknown): string | null {\n    const [, instance] = this.getTaggedValue(instanceWrapped);\n    if (!instance || typeof instance !== \"object\") return null;\n    const inst = instance as { shardAssignments?: { modelId?: string } };\n    return inst.shardAssignments?.modelId ?? 
null;\n  }\n\n  private describeInstance(instanceWrapped: unknown): {\n    sharding: string | null;\n    instanceType: string | null;\n  } {\n    const [instanceTag, instance] = this.getTaggedValue(instanceWrapped);\n    if (!instance || typeof instance !== \"object\") {\n      return { sharding: null, instanceType: null };\n    }\n\n    let instanceType: string | null = null;\n    if (instanceTag === \"MlxRingInstance\") instanceType = \"MLX Ring\";\n    else if (instanceTag === \"MlxJacclInstance\") instanceType = \"MLX RDMA\";\n\n    let sharding: string | null = null;\n    const inst = instance as {\n      shardAssignments?: { runnerToShard?: Record<string, unknown> };\n    };\n    const runnerToShard = inst.shardAssignments?.runnerToShard || {};\n    const firstShardWrapped = Object.values(runnerToShard)[0];\n    if (firstShardWrapped) {\n      const [shardTag] = this.getTaggedValue(firstShardWrapped);\n      if (shardTag === \"PipelineShardMetadata\") sharding = \"Pipeline\";\n      else if (shardTag === \"TensorShardMetadata\") sharding = \"Tensor\";\n      else if (shardTag === \"PrefillDecodeShardMetadata\")\n        sharding = \"Prefill/Decode\";\n    }\n\n    return { sharding, instanceType };\n  }\n\n  private buildConversationModelInfo(modelId: string): {\n    modelId: string;\n    sharding: string | null;\n    instanceType: string | null;\n  } {\n    let sharding: string | null = null;\n    let instanceType: string | null = null;\n\n    for (const [, instanceWrapper] of Object.entries(this.instances)) {\n      const candidateModelId = this.extractInstanceModelId(instanceWrapper);\n      if (candidateModelId === modelId) {\n        const details = this.describeInstance(instanceWrapper);\n        sharding = details.sharding;\n        instanceType = details.instanceType;\n        break;\n      }\n    }\n\n    return { modelId, sharding, instanceType };\n  }\n\n  private applyConversationModelInfo(info: {\n    modelId: string;\n    sharding: string | null;\n    instanceType: string | null;\n  }) {\n    if (!this.activeConversationId) return;\n    const conversation = this.conversations.find(\n      (c) => c.id === this.activeConversationId,\n    );\n    if (!conversation) return;\n\n    // Keep the first known modelId stable; only backfill if missing\n    if (!conversation.modelId) {\n      conversation.modelId = info.modelId;\n    }\n    conversation.sharding = info.sharding;\n    conversation.instanceType = info.instanceType;\n    this.saveConversationsToStorage();\n  }\n\n  private getModelTail(modelId: string): string {\n    const parts = modelId.split(\"/\");\n    return (parts[parts.length - 1] || modelId).toLowerCase();\n  }\n\n  private isBetterModelId(\n    currentId: string | null,\n    candidateId: string | null,\n  ): boolean {\n    if (!candidateId) return false;\n    if (!currentId) return true;\n    const currentTail = this.getModelTail(currentId);\n    const candidateTail = this.getModelTail(candidateId);\n    return (\n      candidateTail.length > currentTail.length &&\n      candidateTail.startsWith(currentTail)\n    );\n  }\n\n  private refreshConversationModelFromInstances() {\n    if (!this.activeConversationId) return;\n    const conversation = this.conversations.find(\n      (c) => c.id === this.activeConversationId,\n    );\n    if (!conversation) return;\n\n    // Prefer stored model; do not replace it once set. 
Only backfill when missing.\n    let modelId = conversation.modelId;\n\n    // If missing, try the selected model\n    if (!modelId && this.selectedChatModel) {\n      modelId = this.selectedChatModel;\n    }\n\n    // If still missing, fall back to first instance model\n    if (!modelId) {\n      const firstInstance = Object.values(this.instances)[0];\n      if (firstInstance) {\n        modelId = this.extractInstanceModelId(firstInstance);\n      }\n    }\n\n    if (!modelId) return;\n\n    // If a more specific instance modelId is available (e.g., adds \"-4bit\"), prefer it\n    let preferredModelId = modelId;\n    for (const [, instanceWrapper] of Object.entries(this.instances)) {\n      const candidate = this.extractInstanceModelId(instanceWrapper);\n      if (!candidate) continue;\n      if (candidate === preferredModelId) {\n        break;\n      }\n      if (this.isBetterModelId(preferredModelId, candidate)) {\n        preferredModelId = candidate;\n      }\n    }\n\n    if (this.isBetterModelId(conversation.modelId, preferredModelId)) {\n      conversation.modelId = preferredModelId;\n    }\n\n    const info = this.buildConversationModelInfo(preferredModelId);\n    const hasNewInfo = Boolean(\n      info.sharding || info.instanceType || !conversation.modelId,\n    );\n    if (hasNewInfo) {\n      this.applyConversationModelInfo(info);\n    }\n  }\n\n  getDebugMode(): boolean {\n    return this.debugMode;\n  }\n\n  /**\n   * Update the active conversation with current messages\n   */\n  private updateActiveConversation() {\n    if (!this.activeConversationId) return;\n\n    const conversation = this.conversations.find(\n      (c) => c.id === this.activeConversationId,\n    );\n    if (conversation) {\n      conversation.messages = [...this.messages];\n      conversation.updatedAt = Date.now();\n\n      // Auto-generate name from first user message if still has default name\n      if (conversation.name.startsWith(\"Chat \")) {\n        const firstUserMsg = conversation.messages.find(\n          (m) => m.role === \"user\" && m.content.trim(),\n        );\n        if (firstUserMsg) {\n          // Clean up the content - remove file context markers and whitespace\n          let content = firstUserMsg.content\n            .replace(/\\[File:.*?\\][\\s\\S]*?```[\\s\\S]*?```/g, \"\") // Remove file attachments\n            .trim();\n\n          if (content) {\n            const preview = content.slice(0, 50);\n            conversation.name =\n              preview.length < content.length ? 
preview + \"...\" : preview;\n          }\n        }\n      }\n\n      this.saveConversationsToStorage();\n    }\n  }\n\n  private persistActiveConversation(throttleMs = 400) {\n    const now = Date.now();\n    if (now - this.lastConversationPersistTs < throttleMs) return;\n    this.lastConversationPersistTs = now;\n    this.updateActiveConversation();\n  }\n\n  /**\n   * Update a message in a specific conversation by ID.\n   * Returns false if conversation or message not found.\n   */\n  private updateConversationMessage(\n    conversationId: string,\n    messageId: string,\n    updater: (message: Message) => void,\n  ): boolean {\n    const conversation = this.conversations.find(\n      (c) => c.id === conversationId,\n    );\n    if (!conversation) return false;\n\n    const message = conversation.messages.find((m) => m.id === messageId);\n    if (!message) return false;\n\n    updater(message);\n    return true;\n  }\n\n  /**\n   * Sync this.messages from the target conversation if it matches the active conversation.\n   */\n  private syncActiveMessagesIfNeeded(conversationId: string): void {\n    if (this.activeConversationId === conversationId) {\n      const conversation = this.conversations.find(\n        (c) => c.id === conversationId,\n      );\n      if (conversation) {\n        this.messages = [...conversation.messages];\n      }\n    }\n  }\n\n  /**\n   * Check if a conversation still exists.\n   */\n  private conversationExists(conversationId: string): boolean {\n    return this.conversations.some((c) => c.id === conversationId);\n  }\n\n  /**\n   * Persist a specific conversation to storage.\n   */\n  private persistConversation(conversationId: string, throttleMs = 400): void {\n    const now = Date.now();\n    if (now - this.lastConversationPersistTs < throttleMs) return;\n    this.lastConversationPersistTs = now;\n\n    const conversation = this.conversations.find(\n      (c) => c.id === conversationId,\n    );\n    if (conversation) {\n      conversation.updatedAt = Date.now();\n\n      // Auto-generate name from first user message if still has default name\n      if (conversation.name.startsWith(\"Chat \")) {\n        const firstUserMsg = conversation.messages.find(\n          (m) => m.role === \"user\" && m.content.trim(),\n        );\n        if (firstUserMsg) {\n          let content = firstUserMsg.content\n            .replace(/\\[File:.*?\\][\\s\\S]*?```[\\s\\S]*?```/g, \"\")\n            .trim();\n\n          if (content) {\n            const preview = content.slice(0, 50);\n            conversation.name =\n              preview.length < content.length ? 
preview + \"...\" : preview;\n          }\n        }\n      }\n\n      this.saveConversationsToStorage();\n    }\n  }\n\n  /**\n   * Add a message directly to a specific conversation.\n   * Returns the message if added, null if conversation not found.\n   */\n  private addMessageToConversation(\n    conversationId: string,\n    role: \"user\" | \"assistant\",\n    content: string,\n  ): Message | null {\n    const conversation = this.conversations.find(\n      (c) => c.id === conversationId,\n    );\n    if (!conversation) return null;\n\n    const message: Message = {\n      id: generateUUID(),\n      role,\n      content,\n      timestamp: Date.now(),\n    };\n    conversation.messages.push(message);\n    return message;\n  }\n\n  /**\n   * Toggle sidebar visibility\n   */\n  toggleSidebar() {\n    this.isSidebarOpen = !this.isSidebarOpen;\n  }\n\n  setDebugMode(enabled: boolean) {\n    this.debugMode = enabled;\n    this.saveDebugModeToStorage();\n  }\n\n  toggleDebugMode() {\n    this.debugMode = !this.debugMode;\n    this.saveDebugModeToStorage();\n  }\n\n  getTopologyOnlyMode(): boolean {\n    return this.topologyOnlyMode;\n  }\n\n  setTopologyOnlyMode(enabled: boolean) {\n    this.topologyOnlyMode = enabled;\n    this.saveTopologyOnlyModeToStorage();\n  }\n\n  toggleTopologyOnlyMode() {\n    this.topologyOnlyMode = !this.topologyOnlyMode;\n    this.saveTopologyOnlyModeToStorage();\n  }\n\n  getChatSidebarVisible(): boolean {\n    return this.chatSidebarVisible;\n  }\n\n  setChatSidebarVisible(visible: boolean) {\n    this.chatSidebarVisible = visible;\n    this.saveChatSidebarVisibleToStorage();\n  }\n\n  toggleChatSidebarVisible() {\n    this.chatSidebarVisible = !this.chatSidebarVisible;\n    this.saveChatSidebarVisibleToStorage();\n  }\n\n  getMobileChatSidebarOpen(): boolean {\n    return this.mobileChatSidebarOpen;\n  }\n\n  setMobileChatSidebarOpen(open: boolean) {\n    this.mobileChatSidebarOpen = open;\n  }\n\n  toggleMobileChatSidebar() {\n    this.mobileChatSidebarOpen = !this.mobileChatSidebarOpen;\n  }\n\n  getMobileRightSidebarOpen(): boolean {\n    return this.mobileRightSidebarOpen;\n  }\n\n  setMobileRightSidebarOpen(open: boolean) {\n    this.mobileRightSidebarOpen = open;\n  }\n\n  toggleMobileRightSidebar() {\n    this.mobileRightSidebarOpen = !this.mobileRightSidebarOpen;\n  }\n\n  startPolling() {\n    this.fetchState();\n    this.fetchInterval = setInterval(() => this.fetchState(), 1000);\n  }\n\n  stopPolling() {\n    if (this.fetchInterval) {\n      clearInterval(this.fetchInterval);\n      this.fetchInterval = null;\n    }\n    this.stopPreviewsPolling();\n  }\n\n  async fetchState() {\n    try {\n      const response = await fetch(\"/state\");\n      if (!response.ok) {\n        throw new Error(`Failed to fetch state: ${response.status}`);\n      }\n      const data: RawStateResponse = await response.json();\n\n      if (data.topology) {\n        this.topologyData = transformTopology(data.topology, {\n          nodeIdentities: data.nodeIdentities,\n          nodeMemory: data.nodeMemory,\n          nodeSystem: data.nodeSystem,\n          nodeNetwork: data.nodeNetwork,\n        });\n        // Handle topology changes for preview filter\n        this.handleTopologyChange();\n      }\n      if (data.instances) {\n        this.instances = data.instances;\n        this.refreshConversationModelFromInstances();\n      }\n      if (data.runners) {\n        this.runners = data.runners;\n      }\n      if (data.downloads) {\n        this.downloads = data.downloads;\n  
    }\n      if (data.nodeDisk) {\n        this.nodeDisk = data.nodeDisk;\n      }\n      // Node identities (for OS version mismatch detection)\n      this.nodeIdentities = data.nodeIdentities ?? {};\n      // Thunderbolt identifiers per node\n      this.nodeThunderbolt = data.nodeThunderbolt ?? {};\n      // RDMA ctl status per node\n      this.nodeRdmaCtl = data.nodeRdmaCtl ?? {};\n      // Thunderbolt bridge cycles\n      this.thunderboltBridgeCycles = data.thunderboltBridgeCycles ?? [];\n      // Thunderbolt bridge status per node\n      this.nodeThunderboltBridge = data.nodeThunderboltBridge ?? {};\n      this.lastUpdate = Date.now();\n      // Connection recovered\n      if (!this.isConnected) {\n        this.isConnected = true;\n      }\n      this.consecutiveFailures = 0;\n    } catch (error) {\n      this.consecutiveFailures++;\n      if (\n        this.consecutiveFailures >= AppStore.CONNECTION_LOST_THRESHOLD &&\n        this.isConnected\n      ) {\n        this.isConnected = false;\n      }\n      console.error(\"Error fetching state:\", error);\n    }\n  }\n\n  async fetchPlacementPreviews(modelId: string, showLoading = true) {\n    if (!modelId) return;\n\n    if (showLoading) {\n      this.isLoadingPreviews = true;\n    }\n    this.selectedPreviewModelId = modelId;\n\n    try {\n      let url = `/instance/previews?model_id=${encodeURIComponent(modelId)}`;\n      // Add node filter if active\n      if (this.previewNodeFilter.size > 0) {\n        for (const nodeId of this.previewNodeFilter) {\n          url += `&node_ids=${encodeURIComponent(nodeId)}`;\n        }\n      }\n      const response = await fetch(url);\n      if (!response.ok) {\n        throw new Error(\n          `Failed to fetch placement previews: ${response.status}`,\n        );\n      }\n      const data: PlacementPreviewResponse = await response.json();\n      this.placementPreviews = data.previews;\n    } catch (error) {\n      console.error(\"Error fetching placement previews:\", error);\n      this.placementPreviews = [];\n    } finally {\n      if (showLoading) {\n        this.isLoadingPreviews = false;\n      }\n    }\n  }\n\n  startPreviewsPolling(modelId: string) {\n    // Stop any existing preview polling\n    this.stopPreviewsPolling();\n\n    // Fetch immediately\n    this.fetchPlacementPreviews(modelId);\n\n    // Then poll every 15 seconds (don't show loading spinner for subsequent fetches)\n    this.previewsInterval = setInterval(() => {\n      if (this.selectedPreviewModelId) {\n        this.fetchPlacementPreviews(this.selectedPreviewModelId, false);\n      }\n    }, 15000);\n  }\n\n  stopPreviewsPolling() {\n    if (this.previewsInterval) {\n      clearInterval(this.previewsInterval);\n      this.previewsInterval = null;\n    }\n  }\n\n  selectPreviewModel(modelId: string | null) {\n    if (modelId) {\n      this.startPreviewsPolling(modelId);\n    } else {\n      this.stopPreviewsPolling();\n      this.selectedPreviewModelId = null;\n      this.placementPreviews = [];\n    }\n  }\n\n  /**\n   * Toggle a node in the preview filter and re-fetch placements\n   */\n  togglePreviewNodeFilter(nodeId: string) {\n    const next = new Set(this.previewNodeFilter);\n    if (next.has(nodeId)) {\n      next.delete(nodeId);\n    } else {\n      next.add(nodeId);\n    }\n    this.previewNodeFilter = next;\n    // Re-fetch with new filter if we have a selected model\n    if (this.selectedPreviewModelId) {\n      this.fetchPlacementPreviews(this.selectedPreviewModelId, false);\n    }\n  }\n\n  /**\n   * Clear 
the preview node filter and re-fetch placements\n   */\n  clearPreviewNodeFilter() {\n    this.previewNodeFilter = new Set();\n    // Re-fetch with no filter if we have a selected model\n    if (this.selectedPreviewModelId) {\n      this.fetchPlacementPreviews(this.selectedPreviewModelId, false);\n    }\n  }\n\n  /**\n   * Handle topology changes - clean up filter and re-fetch if needed\n   */\n  private handleTopologyChange() {\n    if (!this.topologyData) return;\n\n    const currentNodeIds = new Set(Object.keys(this.topologyData.nodes));\n\n    // Check if nodes have changed\n    const nodesAdded = [...currentNodeIds].some(\n      (id) => !this.previousNodeIds.has(id),\n    );\n    const nodesRemoved = [...this.previousNodeIds].some(\n      (id) => !currentNodeIds.has(id),\n    );\n\n    if (nodesAdded || nodesRemoved) {\n      // Clean up filter - remove any nodes that no longer exist\n      if (this.previewNodeFilter.size > 0) {\n        const validFilterNodes = new Set(\n          [...this.previewNodeFilter].filter((id) => currentNodeIds.has(id)),\n        );\n        if (validFilterNodes.size !== this.previewNodeFilter.size) {\n          this.previewNodeFilter = validFilterNodes;\n        }\n      }\n\n      // Re-fetch previews if we have a selected model (topology changed)\n      if (this.selectedPreviewModelId) {\n        this.fetchPlacementPreviews(this.selectedPreviewModelId, false);\n      }\n    }\n\n    // Update tracked node IDs for next comparison\n    this.previousNodeIds = currentNodeIds;\n  }\n\n  /**\n   * Starts a chat conversation - triggers the topology minimization animation\n   * Creates a new conversation if none is active\n   */\n  startChat() {\n    if (!this.activeConversationId) {\n      this.createConversation();\n    } else {\n      this.hasStartedChat = true;\n      this.isSidebarOpen = true; // Auto-open sidebar when chatting\n      // Small delay before minimizing for a nice visual effect\n      setTimeout(() => {\n        this.isTopologyMinimized = true;\n      }, 100);\n    }\n  }\n\n  /**\n   * Add a message to the conversation\n   */\n  addMessage(role: \"user\" | \"assistant\", content: string) {\n    const message: Message = {\n      id: generateUUID(),\n      role,\n      content,\n      timestamp: Date.now(),\n    };\n    this.messages.push(message);\n    return message;\n  }\n\n  /**\n   * Delete a message and all subsequent messages\n   */\n  deleteMessage(messageId: string) {\n    const messageIndex = this.messages.findIndex((m) => m.id === messageId);\n    if (messageIndex === -1) return;\n\n    // Remove this message and all subsequent messages\n    this.messages = this.messages.slice(0, messageIndex);\n    this.updateActiveConversation();\n  }\n\n  /**\n   * Edit a user message content (does not regenerate response)\n   */\n  editMessage(messageId: string, newContent: string) {\n    const message = this.messages.find((m) => m.id === messageId);\n    if (!message) return;\n\n    message.content = newContent;\n    message.timestamp = Date.now();\n    this.updateActiveConversation();\n  }\n\n  /**\n   * Edit a user message and regenerate the response\n   */\n  async editAndRegenerate(\n    messageId: string,\n    newContent: string,\n  ): Promise<void> {\n    const messageIndex = this.messages.findIndex((m) => m.id === messageId);\n    if (messageIndex === -1) return;\n\n    const message = this.messages[messageIndex];\n    if (message.role !== \"user\") return;\n\n    // Update the message content\n    message.content = newContent;\n    
message.timestamp = Date.now();\n\n    // Remove all messages after this one (including the assistant response)\n    this.messages = this.messages.slice(0, messageIndex + 1);\n\n    // Regenerate the response\n    await this.regenerateLastResponse();\n  }\n\n  /**\n   * Regenerate the last assistant response\n   */\n  async regenerateLastResponse(): Promise<void> {\n    if (this.isLoading) return;\n\n    // Find the last user message\n    let lastUserIndex = -1;\n    for (let i = this.messages.length - 1; i >= 0; i--) {\n      if (this.messages[i].role === \"user\") {\n        lastUserIndex = i;\n        break;\n      }\n    }\n\n    if (lastUserIndex === -1) return;\n\n    const lastUserMessage = this.messages[lastUserIndex];\n    const requestType = lastUserMessage.requestType || \"chat\";\n    const prompt = lastUserMessage.content;\n\n    // Remove messages after user message (including the user message for image requests\n    // since generateImage/editImage will re-add it)\n    this.messages = this.messages.slice(0, lastUserIndex);\n    this.updateActiveConversation();\n\n    switch (requestType) {\n      case \"image-generation\":\n        await this.generateImage(prompt);\n        break;\n      case \"image-editing\":\n        if (lastUserMessage.sourceImageDataUrl) {\n          await this.editImage(prompt, lastUserMessage.sourceImageDataUrl);\n        } else {\n          // Can't regenerate edit without source image - restore user message and show error\n          this.messages.push(lastUserMessage);\n          const errorMessage = this.addMessage(\"assistant\", \"\");\n          const idx = this.messages.findIndex((m) => m.id === errorMessage.id);\n          if (idx !== -1) {\n            this.messages[idx].content =\n              \"Error: Cannot regenerate image edit - source image not found\";\n          }\n          this.updateActiveConversation();\n        }\n        break;\n      case \"chat\":\n      default:\n        // Restore the user message for chat regeneration\n        this.messages.push(lastUserMessage);\n        await this.regenerateChatCompletion();\n        break;\n    }\n  }\n\n  /**\n   * Regenerate response from a specific token index.\n   * Truncates the assistant message at the given token and re-generates from there.\n   */\n  async regenerateFromToken(\n    messageId: string,\n    tokenIndex: number,\n  ): Promise<void> {\n    if (this.isLoading) return;\n\n    const targetConversationId = this.activeConversationId;\n    if (!targetConversationId) return;\n\n    const msgIndex = this.messages.findIndex((m) => m.id === messageId);\n    if (msgIndex === -1) return;\n\n    const msg = this.messages[msgIndex];\n    if (\n      msg.role !== \"assistant\" ||\n      !msg.tokens ||\n      tokenIndex >= msg.tokens.length\n    )\n      return;\n\n    // Keep tokens up to (not including) the specified index\n    const tokensToKeep = msg.tokens.slice(0, tokenIndex);\n    const prefixText = tokensToKeep.map((t) => t.token).join(\"\");\n\n    // Remove all messages after this assistant message\n    this.messages = this.messages.slice(0, msgIndex + 1);\n\n    // Update the message to show the prefix\n    this.messages[msgIndex].content = prefixText;\n    this.messages[msgIndex].tokens = tokensToKeep;\n    this.updateActiveConversation();\n\n    // Set up for continuation - modify the existing message in place\n    this.isLoading = true;\n    this.currentResponse = prefixText;\n    this.ttftMs = null;\n    this.tps = null;\n    this.totalTokens = tokensToKeep.length;\n\n 
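   // Re-send the conversation (the kept token prefix is already in\n    // this.messages) and stream the continuation into this same message.\n 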
   try {\n      // Build messages for API - include the partial assistant message\n      const systemPrompt = {\n        role: \"system\" as const,\n        content:\n          \"You are a helpful AI assistant. Respond directly and concisely. Do not show your reasoning or thought process.\",\n      };\n\n      const apiMessages = [\n        systemPrompt,\n        ...this.messages.map((m) => {\n          let msgContent = m.content;\n          if (m.attachments) {\n            for (const attachment of m.attachments) {\n              if (attachment.type === \"text\" && attachment.content) {\n                msgContent += `\\n\\n[File: ${attachment.name}]\\n\\`\\`\\`\\n${attachment.content}\\n\\`\\`\\``;\n              }\n            }\n          }\n          return { role: m.role, content: msgContent };\n        }),\n      ];\n\n      const modelToUse = this.getModelForRequest();\n      if (!modelToUse) {\n        throw new Error(\"No model available\");\n      }\n\n      const requestStartTime = performance.now();\n      let firstTokenTime: number | null = null;\n      let tokenCount = tokensToKeep.length;\n\n      const response = await fetch(\"/v1/chat/completions\", {\n        method: \"POST\",\n        headers: { \"Content-Type\": \"application/json\" },\n        body: JSON.stringify({\n          model: modelToUse,\n          messages: apiMessages,\n          stream: true,\n          logprobs: true,\n          top_logprobs: 5,\n        }),\n      });\n\n      if (!response.ok) {\n        const errorText = await response.text();\n        throw new Error(`API error: ${response.status} - ${errorText}`);\n      }\n\n      const reader = response.body?.getReader();\n      if (!reader) throw new Error(\"No response body\");\n\n      let fullContent = prefixText;\n      let streamedThinking = \"\";\n      const collectedTokens: TokenData[] = [...tokensToKeep];\n\n      interface ChatCompletionChunk {\n        choices?: Array<{\n          delta?: { content?: string; reasoning_content?: string };\n          logprobs?: {\n            content?: Array<{\n              token: string;\n              logprob: number;\n              top_logprobs?: Array<{\n                token: string;\n                logprob: number;\n                bytes: number[] | null;\n              }>;\n            }>;\n          };\n        }>;\n      }\n\n      await this.parseSSEStream<ChatCompletionChunk>(\n        reader,\n        targetConversationId,\n        (parsed) => {\n          const choice = parsed.choices?.[0];\n          const delta = choice?.delta?.content;\n          const thinkingDelta = choice?.delta?.reasoning_content;\n\n          // Collect logprobs data\n          const logprobsContent = choice?.logprobs?.content;\n          if (logprobsContent) {\n            for (const item of logprobsContent) {\n              collectedTokens.push({\n                token: item.token,\n                logprob: item.logprob,\n                probability: Math.exp(item.logprob),\n                topLogprobs: (item.top_logprobs || []).map((t) => ({\n                  token: t.token,\n                  logprob: t.logprob,\n                  bytes: t.bytes,\n                })),\n              });\n            }\n          }\n\n          if (thinkingDelta) {\n            streamedThinking += thinkingDelta;\n          }\n\n          if (delta || thinkingDelta) {\n            if (firstTokenTime === null) {\n              firstTokenTime = performance.now();\n              this.ttftMs = firstTokenTime - requestStartTime;\n      
      }\n\n            tokenCount += 1;\n            this.totalTokens = tokenCount;\n\n            if (firstTokenTime !== null && tokenCount > tokensToKeep.length) {\n              const elapsed = performance.now() - firstTokenTime;\n              this.tps = ((tokenCount - tokensToKeep.length) / elapsed) * 1000;\n            }\n\n            if (delta) {\n              fullContent += delta;\n            }\n            const { displayContent, thinkingContent: tagThinking } =\n              this.stripThinkingTags(fullContent);\n            const combinedThinking = [streamedThinking, tagThinking]\n              .filter(Boolean)\n              .join(\"\\n\\n\");\n\n            if (this.activeConversationId === targetConversationId) {\n              this.currentResponse = displayContent;\n            }\n\n            // Update existing message in place\n            this.updateConversationMessage(\n              targetConversationId,\n              messageId,\n              (m) => {\n                m.content = displayContent;\n                m.thinking = combinedThinking || undefined;\n                m.tokens = [...collectedTokens];\n              },\n            );\n            this.syncActiveMessagesIfNeeded(targetConversationId);\n            this.persistConversation(targetConversationId);\n          }\n        },\n      );\n\n      // Final update\n      if (this.conversationExists(targetConversationId)) {\n        const { displayContent, thinkingContent: tagThinking } =\n          this.stripThinkingTags(fullContent);\n        const finalThinking = [streamedThinking, tagThinking]\n          .filter(Boolean)\n          .join(\"\\n\\n\");\n        this.updateConversationMessage(targetConversationId, messageId, (m) => {\n          m.content = displayContent;\n          m.thinking = finalThinking || undefined;\n          m.tokens = [...collectedTokens];\n          if (this.ttftMs !== null) m.ttftMs = this.ttftMs;\n          if (this.tps !== null) m.tps = this.tps;\n        });\n        this.syncActiveMessagesIfNeeded(targetConversationId);\n        this.persistConversation(targetConversationId);\n      }\n    } catch (error) {\n      console.error(\"Error regenerating from token:\", error);\n      if (this.conversationExists(targetConversationId)) {\n        this.updateConversationMessage(targetConversationId, messageId, (m) => {\n          m.content = `${prefixText}\\n\\nError: ${error instanceof Error ? 
error.message : \"Unknown error\"}`;\n        });\n        this.syncActiveMessagesIfNeeded(targetConversationId);\n        this.persistConversation(targetConversationId);\n      }\n    } finally {\n      this.isLoading = false;\n      this.currentResponse = \"\";\n      this.saveConversationsToStorage();\n    }\n  }\n\n  /**\n   * Helper method to regenerate a chat completion response\n   */\n  private async regenerateChatCompletion(): Promise<void> {\n    // Capture the target conversation ID at the start of the request\n    const targetConversationId = this.activeConversationId;\n    if (!targetConversationId) return;\n\n    const targetConversation = this.conversations.find(\n      (c) => c.id === targetConversationId,\n    );\n    if (!targetConversation) return;\n\n    this.isLoading = true;\n    this.currentResponse = \"\";\n\n    // Create placeholder for assistant message directly in target conversation\n    const assistantMessage = this.addMessageToConversation(\n      targetConversationId,\n      \"assistant\",\n      \"\",\n    );\n    if (!assistantMessage) {\n      this.isLoading = false;\n      return;\n    }\n\n    // Sync to this.messages if viewing the target conversation\n    this.syncActiveMessagesIfNeeded(targetConversationId);\n\n    try {\n      const systemPrompt = {\n        role: \"system\" as const,\n        content:\n          \"You are a helpful AI assistant. Respond directly and concisely. Do not show your reasoning or thought process.\",\n      };\n\n      const apiMessages = [\n        systemPrompt,\n        ...targetConversation.messages.slice(0, -1).map((m) => {\n          return { role: m.role, content: m.content };\n        }),\n      ];\n\n      // Determine which model to use\n      const modelToUse = this.getModelForRequest();\n      if (!modelToUse) {\n        this.updateConversationMessage(\n          targetConversationId,\n          assistantMessage.id,\n          (msg) => {\n            msg.content =\n              \"No model is loaded yet. 
Select a model from the sidebar to get started — it will download and load automatically.\";\n          },\n        );\n        this.syncActiveMessagesIfNeeded(targetConversationId);\n        this.isLoading = false;\n        this.saveConversationsToStorage();\n        return;\n      }\n\n      const response = await fetch(\"/v1/chat/completions\", {\n        method: \"POST\",\n        headers: { \"Content-Type\": \"application/json\" },\n        body: JSON.stringify({\n          model: modelToUse,\n          messages: apiMessages,\n          stream: true,\n          logprobs: true,\n          top_logprobs: 5,\n        }),\n      });\n\n      if (!response.ok) {\n        const errorText = await response.text();\n        throw new Error(`${response.status} - ${errorText}`);\n      }\n\n      const reader = response.body?.getReader();\n      if (!reader) {\n        throw new Error(\"No response stream available\");\n      }\n\n      let streamedContent = \"\";\n      let streamedThinking = \"\";\n      const collectedTokens: TokenData[] = [];\n\n      interface ChatCompletionChunk {\n        choices?: Array<{\n          delta?: { content?: string; reasoning_content?: string };\n          logprobs?: {\n            content?: Array<{\n              token: string;\n              logprob: number;\n              top_logprobs?: Array<{\n                token: string;\n                logprob: number;\n                bytes: number[] | null;\n              }>;\n            }>;\n          };\n        }>;\n      }\n\n      await this.parseSSEStream<ChatCompletionChunk>(\n        reader,\n        targetConversationId,\n        (parsed) => {\n          const choice = parsed.choices?.[0];\n          const delta = choice?.delta?.content;\n          const thinkingDelta = choice?.delta?.reasoning_content;\n\n          // Collect logprobs data\n          const logprobsContent = choice?.logprobs?.content;\n          if (logprobsContent) {\n            for (const item of logprobsContent) {\n              collectedTokens.push({\n                token: item.token,\n                logprob: item.logprob,\n                probability: Math.exp(item.logprob),\n                topLogprobs: (item.top_logprobs || []).map((t) => ({\n                  token: t.token,\n                  logprob: t.logprob,\n                  bytes: t.bytes,\n                })),\n              });\n            }\n          }\n\n          if (thinkingDelta) {\n            streamedThinking += thinkingDelta;\n          }\n\n          if (delta || thinkingDelta) {\n            if (delta) {\n              streamedContent += delta;\n            }\n            const { displayContent, thinkingContent: tagThinking } =\n              this.stripThinkingTags(streamedContent);\n            const combinedThinking = [streamedThinking, tagThinking]\n              .filter(Boolean)\n              .join(\"\\n\\n\");\n\n            // Only update currentResponse if target conversation is active\n            if (this.activeConversationId === targetConversationId) {\n              this.currentResponse = displayContent;\n            }\n\n            // Update the assistant message in the target conversation\n            this.updateConversationMessage(\n              targetConversationId,\n              assistantMessage.id,\n              (msg) => {\n                msg.content = displayContent;\n                msg.thinking = combinedThinking || undefined;\n                msg.tokens = [...collectedTokens];\n              },\n            );\n            
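// Mirror the update into this.messages when this conversation is on\n            // screen, and persist after each chunk (presumably so a partial\n            // response survives a reload).\n            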
this.syncActiveMessagesIfNeeded(targetConversationId);\n            this.persistConversation(targetConversationId);\n          }\n        },\n      );\n\n      // Final cleanup of the message (if conversation still exists)\n      if (this.conversationExists(targetConversationId)) {\n        const { displayContent, thinkingContent: tagThinking } =\n          this.stripThinkingTags(streamedContent);\n        const finalThinking = [streamedThinking, tagThinking]\n          .filter(Boolean)\n          .join(\"\\n\\n\");\n        this.updateConversationMessage(\n          targetConversationId,\n          assistantMessage.id,\n          (msg) => {\n            msg.content = displayContent;\n            msg.thinking = finalThinking || undefined;\n            msg.tokens = [...collectedTokens];\n          },\n        );\n        this.syncActiveMessagesIfNeeded(targetConversationId);\n        this.persistConversation(targetConversationId);\n      }\n    } catch (error) {\n      this.handleStreamingError(\n        error,\n        targetConversationId,\n        assistantMessage.id,\n        \"Unknown error\",\n      );\n    } finally {\n      this.isLoading = false;\n      this.currentResponse = \"\";\n      this.saveConversationsToStorage();\n    }\n  }\n\n  /**\n   * Whether thinking is enabled for the current conversation\n   */\n  thinkingEnabled = $state(true);\n\n  /**\n   * Selected model for chat (can be set by the UI)\n   */\n  selectedChatModel = $state(\"\");\n\n  /**\n   * Set the model to use for chat\n   */\n  setSelectedModel(modelId: string) {\n    this.selectedChatModel = modelId;\n    // Clear stats when model changes\n    this.ttftMs = null;\n    this.tps = null;\n  }\n\n  /**\n   * Strip thinking tags from content for display.\n   * Handles both complete <think>...</think> blocks and in-progress <think>... 
blocks during streaming.\n   */\n  private stripThinkingTags(content: string): {\n    displayContent: string;\n    thinkingContent: string;\n  } {\n    const extracted: string[] = [];\n    let displayContent = content;\n\n    // Extract complete <think>...</think> blocks\n    const completeBlockRegex = /<think>([\\s\\S]*?)<\\/think>/gi;\n    let match: RegExpExecArray | null;\n    while ((match = completeBlockRegex.exec(content)) !== null) {\n      const inner = match[1]?.trim();\n      if (inner) extracted.push(inner);\n    }\n    displayContent = displayContent.replace(completeBlockRegex, \"\");\n\n    // Handle in-progress thinking block (has <think> but no closing </think> yet)\n    const openTagIndex = displayContent.lastIndexOf(\"<think>\");\n    if (openTagIndex !== -1) {\n      const inProgressThinking = displayContent\n        .slice(openTagIndex + \"<think>\".length)\n        .trim();\n      if (inProgressThinking) {\n        extracted.push(inProgressThinking);\n      }\n      displayContent = displayContent.slice(0, openTagIndex);\n    }\n\n    return {\n      displayContent: displayContent.trim(),\n      thinkingContent: extracted.join(\"\\n\\n\"),\n    };\n  }\n\n  /**\n   * Parse an SSE stream and invoke a callback for each parsed JSON chunk.\n   * Handles buffering, line splitting, SSE comment events, and conversation\n   * deletion checks.\n   *\n   * @param reader - The stream reader from fetch response.body.getReader()\n   * @param targetConversationId - The conversation ID to check for deletion\n   * @param onChunk - Callback invoked with each parsed JSON object from the stream\n   * @param onEvent - Optional handlers keyed by SSE comment name (e.g.\n   *   \"prefill_progress\") for \": key json\" comment lines\n   */\n  private async parseSSEStream<T>(\n    reader: ReadableStreamDefaultReader<Uint8Array>,\n    targetConversationId: string,\n    onChunk: (parsed: T) => void,\n    onEvent?: Record<string, (data: unknown) => void>,\n  ): Promise<void> {\n    const decoder = new TextDecoder();\n    let buffer = \"\";\n\n    while (true) {\n      const { done, value } = await reader.read();\n      if (done) break;\n\n      if (!this.conversationExists(targetConversationId)) {\n        break;\n      }\n\n      buffer += decoder.decode(value, { stream: true });\n      const lines = buffer.split(\"\\n\");\n      buffer = lines.pop() || \"\";\n\n      for (const line of lines) {\n        const trimmed = line.trim();\n        if (!trimmed) continue;\n\n        // Handle SSE comments (\": key json\") for prefill progress etc.\n        if (trimmed.startsWith(\": \") && onEvent) {\n          const comment = trimmed.slice(2);\n          const spaceIdx = comment.indexOf(\" \");\n          if (spaceIdx > 0) {\n            const key = comment.slice(0, spaceIdx);\n            if (onEvent[key]) {\n              try {\n                const parsed = JSON.parse(comment.slice(spaceIdx + 1));\n                onEvent[key](parsed);\n              } catch {\n                // Skip malformed JSON in comment\n              }\n            }\n          }\n          continue;\n        }\n\n        if (trimmed.startsWith(\"data: \")) {\n          const data = trimmed.slice(6);\n          if (data === \"[DONE]\") continue;\n\n          try {\n            const parsed = JSON.parse(data) as T;\n            onChunk(parsed);\n          } catch {\n            // Skip malformed JSON\n          }\n        }\n      }\n    }\n\n    // Process any remaining data in the buffer\n    if (buffer.trim() && this.conversationExists(targetConversationId)) {\n      const trimmed = buffer.trim();\n      if (trimmed.startsWith(\"data: \") && trimmed.slice(6) !== \"[DONE]\") {\n        try {\n          const parsed = 
JSON.parse(trimmed.slice(6)) as T;\n          onChunk(parsed);\n        } catch {\n          // Skip malformed JSON\n        }\n      }\n    }\n  }\n\n  /**\n   * Handle streaming errors by updating the assistant message with an error.\n   *\n   * @param error - The caught error\n   * @param targetConversationId - The conversation ID\n   * @param assistantMessageId - The assistant message ID to update\n   * @param fallbackErrorMessage - Message used when the caught value is not an\n   *   Error instance (e.g., \"Failed to generate image\")\n   */\n  private handleStreamingError(\n    error: unknown,\n    targetConversationId: string,\n    assistantMessageId: string,\n    fallbackErrorMessage = \"Failed to get response\",\n  ): void {\n    if (this.conversationExists(targetConversationId)) {\n      this.updateConversationMessage(\n        targetConversationId,\n        assistantMessageId,\n        (msg) => {\n          msg.content = \`Error: ${error instanceof Error ? error.message : fallbackErrorMessage}\`;\n        },\n      );\n      this.syncActiveMessagesIfNeeded(targetConversationId);\n      this.persistConversation(targetConversationId);\n    }\n  }\n\n  /**\n   * Get the model to use for a request.\n   * Prefers the provided modelId, then selectedChatModel, then falls back to the first running instance.\n   *\n   * @param modelId - Optional explicit model ID\n   * @returns The model ID to use, or null if none available\n   */\n  private getModelForRequest(modelId?: string): string | null {\n    if (modelId) return modelId;\n    if (this.selectedChatModel) return this.selectedChatModel;\n\n    // Try to get model from first running instance\n    for (const [, instanceWrapper] of Object.entries(this.instances)) {\n      if (instanceWrapper && typeof instanceWrapper === \"object\") {\n        const keys = Object.keys(instanceWrapper as Record<string, unknown>);\n        if (keys.length === 1) {\n          const instance = (instanceWrapper as Record<string, unknown>)[\n            keys[0]\n          ] as { shardAssignments?: { modelId?: string } };\n          if (instance?.shardAssignments?.modelId) {\n            return instance.shardAssignments.modelId;\n          }\n        }\n      }\n    }\n    return null;\n  }\n\n  /**\n   * Send a message to the LLM and stream the response\n   */\n  async sendMessage(\n    content: string,\n    files?: {\n      id: string;\n      name: string;\n      type: string;\n      textContent?: string;\n      preview?: string;\n    }[],\n    enableThinking?: boolean | null,\n  ): Promise<void> {\n    if ((!content.trim() && (!files || files.length === 0)) || this.isLoading)\n      return;\n\n    if (!this.hasStartedChat) {\n      this.startChat();\n    }\n\n    // Capture the target conversation ID at the start of the request\n    const targetConversationId = this.activeConversationId;\n    if (!targetConversationId) return;\n\n    this.isLoading = true;\n    this.currentResponse = \"\";\n    this.ttftMs = null;\n    this.tps = null;\n    this.totalTokens = 0;\n\n    // Build attachments from files\n    const attachments: MessageAttachment[] = [];\n\n    if (files && files.length > 0) {\n      for (const file of files) {\n        const isImage = file.type.startsWith(\"image/\");\n\n        if (isImage && file.preview) {\n          attachments.push({\n            type: \"image\",\n            name: file.name,\n            preview: file.preview,\n            mimeType: file.type,\n          });\n        } else if (file.textContent) {\n          attachments.push({\n            type: \"text\",\n            name: file.name,\n            content: file.textContent,\n            mimeType: file.type,\n          });\n          // The text content reaches the API via the attachment mapping in apiMessages below\n        } else {\n          attachments.push({\n            type: \"file\",\n            name: file.name,\n            mimeType: file.type,\n          });\n        }\n      }\n    }\n\n    // Add user message directly to the target conversation\n    const userMessage: Message = {\n      id: generateUUID(),\n      role: \"user\",\n      content: content, // Store original content for display\n      timestamp: Date.now(),\n      attachments: attachments.length > 0 ? attachments : undefined,\n    };\n\n    const targetConversation = this.conversations.find(\n      (c) => c.id === targetConversationId,\n    );\n    if (!targetConversation) {\n      this.isLoading = false;\n      return;\n    }\n    targetConversation.messages.push(userMessage);\n\n    // Create placeholder for assistant message directly in target conversation\n    const assistantMessage = this.addMessageToConversation(\n      targetConversationId,\n      \"assistant\",\n      \"\",\n    );\n    if (!assistantMessage) {\n      this.isLoading = false;\n      return;\n    }\n\n    // Sync to this.messages if viewing the target conversation\n    this.syncActiveMessagesIfNeeded(targetConversationId);\n    this.saveConversationsToStorage();\n\n    try {\n      // Build the messages array for the API with system prompt\n      const systemPrompt = {\n        role: \"system\" as const,\n        content:\n          \"You are a helpful AI assistant. Respond directly and concisely. Do not show your reasoning or thought process. When files are shared with you, analyze them and respond helpfully.\",\n      };\n\n      // Build API messages from the target conversation - include file content for text files\n      const apiMessages = [\n        systemPrompt,\n        ...targetConversation.messages.slice(0, -1).map((m) => {\n          // Build content including any text file attachments\n          let msgContent = m.content;\n\n          // Add text attachments as context\n          if (m.attachments) {\n            for (const attachment of m.attachments) {\n              if (attachment.type === \"text\" && attachment.content) {\n                msgContent += \`\\n\\n[File: ${attachment.name}]\\n\\`\\`\\`\\n${attachment.content}\\n\\`\\`\\`\`;\n              }\n            }\n          }\n\n          return {\n            role: m.role,\n            content: msgContent,\n          };\n        }),\n      ];\n\n      // Determine the model to use\n      const modelToUse = this.getModelForRequest();\n      if (!modelToUse) {\n        throw new Error(\n          \"No model is loaded yet. 
Select a model from the sidebar to get started — it will download and load automatically.\",\n        );\n      }\n\n      const conversationModelInfo = this.buildConversationModelInfo(modelToUse);\n      this.applyConversationModelInfo(conversationModelInfo);\n\n      // Start timing for TTFT measurement\n      const requestStartTime = performance.now();\n      let firstTokenTime: number | null = null;\n      let tokenCount = 0;\n\n      const abortController = new AbortController();\n      this.currentAbortController = abortController;\n\n      const response = await fetch(\"/v1/chat/completions\", {\n        method: \"POST\",\n        headers: {\n          \"Content-Type\": \"application/json\",\n        },\n        body: JSON.stringify({\n          model: modelToUse,\n          messages: apiMessages,\n          temperature: 0.7,\n          stream: true,\n          logprobs: true,\n          top_logprobs: 5,\n          ...(enableThinking != null && {\n            enable_thinking: enableThinking,\n          }),\n        }),\n        signal: abortController.signal,\n      });\n\n      if (!response.ok) {\n        const errorText = await response.text();\n        throw new Error(`API error: ${response.status} - ${errorText}`);\n      }\n\n      const reader = response.body?.getReader();\n      if (!reader) {\n        throw new Error(\"No response body\");\n      }\n\n      let streamedContent = \"\";\n      let streamedThinking = \"\";\n\n      interface ChatCompletionChunk {\n        choices?: Array<{\n          delta?: { content?: string; reasoning_content?: string };\n          logprobs?: {\n            content?: Array<{\n              token: string;\n              logprob: number;\n              top_logprobs?: Array<{\n                token: string;\n                logprob: number;\n                bytes: number[] | null;\n              }>;\n            }>;\n          };\n        }>;\n      }\n\n      const collectedTokens: TokenData[] = [];\n\n      await this.parseSSEStream<ChatCompletionChunk>(\n        reader,\n        targetConversationId,\n        (parsed) => {\n          // Clear prefill progress when first token data arrives\n          if (this.prefillProgress) {\n            this.prefillProgress = null;\n          }\n\n          const choice = parsed.choices?.[0];\n          const tokenContent = choice?.delta?.content;\n          const thinkingContent = choice?.delta?.reasoning_content;\n\n          // Collect logprobs data\n          const logprobsContent = choice?.logprobs?.content;\n          if (logprobsContent) {\n            for (const item of logprobsContent) {\n              collectedTokens.push({\n                token: item.token,\n                logprob: item.logprob,\n                probability: Math.exp(item.logprob),\n                topLogprobs: (item.top_logprobs || []).map((t) => ({\n                  token: t.token,\n                  logprob: t.logprob,\n                  bytes: t.bytes,\n                })),\n              });\n            }\n          }\n\n          if (thinkingContent) {\n            streamedThinking += thinkingContent;\n          }\n\n          if (tokenContent || thinkingContent) {\n            // Track first token for TTFT\n            if (firstTokenTime === null) {\n              firstTokenTime = performance.now();\n              this.ttftMs = firstTokenTime - requestStartTime;\n            }\n\n            // Count tokens (each SSE chunk is typically one token)\n            tokenCount += 1;\n            this.totalTokens = 
tokenCount;\n\n            // Update real-time TPS during streaming; the first token opens\n            // the timing window, so it is excluded from the count\n            if (firstTokenTime !== null && tokenCount > 1) {\n              const elapsed = performance.now() - firstTokenTime;\n              this.tps = ((tokenCount - 1) / elapsed) * 1000;\n            }\n\n            if (tokenContent) {\n              streamedContent += tokenContent;\n            }\n\n            // Use stripThinkingTags as fallback for any <think> tags still in content\n            const { displayContent, thinkingContent: tagThinking } =\n              this.stripThinkingTags(streamedContent);\n            const combinedThinking = [streamedThinking, tagThinking]\n              .filter(Boolean)\n              .join(\"\\n\\n\");\n\n            // Only update currentResponse if target conversation is active\n            if (this.activeConversationId === targetConversationId) {\n              this.currentResponse = displayContent;\n            }\n\n            // Update the assistant message in the target conversation\n            this.updateConversationMessage(\n              targetConversationId,\n              assistantMessage.id,\n              (msg) => {\n                msg.content = displayContent;\n                msg.thinking = combinedThinking || undefined;\n                msg.tokens = [...collectedTokens];\n              },\n            );\n            this.syncActiveMessagesIfNeeded(targetConversationId);\n            this.persistConversation(targetConversationId);\n          }\n        },\n        {\n          prefill_progress: (data) => {\n            // TaggedModel wraps as {\"PrefillProgressChunk\": {...}}\n            // model_dump_json() uses snake_case (by_alias defaults to False)\n            const raw = data as Record<string, unknown>;\n            const inner = (raw[\"PrefillProgressChunk\"] ?? raw) as {\n              processed_tokens: number;\n              total_tokens: number;\n            };\n            this.prefillProgress = {\n              processed: inner.processed_tokens,\n              total: inner.total_tokens,\n              startedAt: this.prefillProgress?.startedAt ?? 
performance.now(),\n            };\n          },\n        },\n      );\n\n      // Clear prefill progress after stream ends\n      this.prefillProgress = null;\n\n      // Calculate final tokens-per-second; the first token opens the timing\n      // window, so it is excluded from the count\n      if (firstTokenTime !== null && tokenCount > 1) {\n        const totalGenerationTime = performance.now() - firstTokenTime;\n        this.tps = ((tokenCount - 1) / totalGenerationTime) * 1000;\n      }\n\n      // Final cleanup of the message (if conversation still exists)\n      if (this.conversationExists(targetConversationId)) {\n        const { displayContent, thinkingContent: tagThinking } =\n          this.stripThinkingTags(streamedContent);\n        const finalThinking = [streamedThinking, tagThinking]\n          .filter(Boolean)\n          .join(\"\\n\\n\");\n        this.updateConversationMessage(\n          targetConversationId,\n          assistantMessage.id,\n          (msg) => {\n            msg.content = displayContent;\n            msg.thinking = finalThinking || undefined;\n            msg.tokens = [...collectedTokens];\n            // Store performance metrics on the message\n            if (this.ttftMs !== null) {\n              msg.ttftMs = this.ttftMs;\n            }\n            if (this.tps !== null) {\n              msg.tps = this.tps;\n            }\n          },\n        );\n        this.syncActiveMessagesIfNeeded(targetConversationId);\n        this.persistConversation(targetConversationId);\n      }\n    } catch (error) {\n      if (error instanceof DOMException && error.name === \"AbortError\") {\n        // User stopped generation — not an error\n      } else {\n        console.error(\"Error sending message:\", error);\n        this.handleStreamingError(\n          error,\n          targetConversationId,\n          assistantMessage.id,\n          \"Failed to get response\",\n        );\n      }\n    } finally {\n      this.currentAbortController = null;\n      this.prefillProgress = null;\n      this.isLoading = false;\n      this.currentResponse = \"\";\n      this.saveConversationsToStorage();\n    }\n  }\n\n  stopGeneration(): void {\n    this.currentAbortController?.abort();\n    this.currentAbortController = null;\n  }\n\n  /**\n   * Generate an image using the image generation API\n   */\n  async generateImage(prompt: string, modelId?: string): Promise<void> {\n    if (!prompt.trim() || this.isLoading) return;\n\n    if (!this.hasStartedChat) {\n      this.startChat();\n    }\n\n    // Capture the target conversation ID at the start of the request\n    const targetConversationId = this.activeConversationId;\n    if (!targetConversationId) return;\n\n    this.isLoading = true;\n    this.currentResponse = \"\";\n\n    // Add user message directly to the target conversation\n    const userMessage: Message = {\n      id: generateUUID(),\n      role: \"user\",\n      content: prompt,\n      timestamp: Date.now(),\n      requestType: \"image-generation\",\n    };\n\n    const targetConversation = this.conversations.find(\n      (c) => c.id === targetConversationId,\n    );\n    if (!targetConversation) {\n      this.isLoading = false;\n      return;\n    }\n    targetConversation.messages.push(userMessage);\n\n    // Create placeholder for assistant message directly in target conversation\n    const assistantMessage = this.addMessageToConversation(\n      targetConversationId,\n      \"assistant\",\n      \"Generating image...\",\n    );\n    if (!assistantMessage) {\n      this.isLoading = false;\n      return;\n    }\n\n    // Sync to this.messages if 
viewing the target conversation\n    this.syncActiveMessagesIfNeeded(targetConversationId);\n    this.saveConversationsToStorage();\n\n    try {\n      // Determine the model to use\n      const model = this.getModelForRequest(modelId);\n      if (!model) {\n        throw new Error(\n          \"No model selected. Please select an image generation model.\",\n        );\n      }\n\n      // Build request body using image generation params\n      const params = this.imageGenerationParams;\n      const hasAdvancedParams =\n        params.seed !== null ||\n        params.numInferenceSteps !== null ||\n        params.guidance !== null ||\n        (params.negativePrompt !== null &&\n          params.negativePrompt.trim() !== \"\") ||\n        params.numSyncSteps !== null;\n\n      const requestBody: Record<string, unknown> = {\n        model,\n        prompt,\n        n: params.numImages,\n        quality: params.quality,\n        size: params.size,\n        output_format: params.outputFormat,\n        response_format: \"b64_json\",\n        stream: params.stream,\n        partial_images: params.partialImages,\n      };\n\n      if (hasAdvancedParams) {\n        requestBody.advanced_params = {\n          ...(params.seed !== null && { seed: params.seed }),\n          ...(params.numInferenceSteps !== null && {\n            num_inference_steps: params.numInferenceSteps,\n          }),\n          ...(params.guidance !== null && { guidance: params.guidance }),\n          ...(params.negativePrompt !== null &&\n            params.negativePrompt.trim() !== \"\" && {\n              negative_prompt: params.negativePrompt,\n            }),\n          ...(params.numSyncSteps !== null && {\n            num_sync_steps: params.numSyncSteps,\n          }),\n        };\n      }\n\n      const response = await fetch(\"/v1/images/generations\", {\n        method: \"POST\",\n        headers: {\n          \"Content-Type\": \"application/json\",\n        },\n        body: JSON.stringify(requestBody),\n      });\n\n      if (!response.ok) {\n        const errorText = await response.text();\n        throw new Error(`API error: ${response.status} - ${errorText}`);\n      }\n\n      // Streaming requires both stream=true AND partialImages > 0\n      const isStreaming = params.stream && params.partialImages > 0;\n\n      if (!isStreaming) {\n        // Non-streaming: parse JSON response directly\n        const jsonResponse = (await response.json()) as ImageApiResponse;\n        const format = params.outputFormat || \"png\";\n        const mimeType = `image/${format}`;\n\n        const attachments: MessageAttachment[] = jsonResponse.data\n          .filter((img) => img.b64_json)\n          .map((img, index) => ({\n            type: \"generated-image\" as const,\n            name: `generated-image-${index + 1}.${format}`,\n            preview: `data:${mimeType};base64,${img.b64_json}`,\n            mimeType,\n          }));\n\n        this.updateConversationMessage(\n          targetConversationId,\n          assistantMessage.id,\n          (msg) => {\n            msg.content = \"\";\n            msg.attachments = attachments;\n          },\n        );\n        this.syncActiveMessagesIfNeeded(targetConversationId);\n      } else {\n        // Streaming mode: use SSE parser\n        const reader = response.body?.getReader();\n        if (!reader) {\n          throw new Error(\"No response body\");\n        }\n\n        interface ImageGenerationChunk {\n          data?: { b64_json?: string };\n          format?: string;\n     
     type?: \"partial\" | \"final\";\n          image_index?: number;\n          partial_index?: number;\n          total_partials?: number;\n        }\n\n        const numImages = params.numImages;\n\n        await this.parseSSEStream<ImageGenerationChunk>(\n          reader,\n          targetConversationId,\n          (parsed) => {\n            const imageData = parsed.data?.b64_json;\n\n            if (imageData) {\n              const format = parsed.format || \"png\";\n              const mimeType = `image/${format}`;\n              const imageIndex = parsed.image_index ?? 0;\n\n              if (parsed.type === \"partial\") {\n                // Update with partial image and progress\n                const partialNum = (parsed.partial_index ?? 0) + 1;\n                const totalPartials = parsed.total_partials ?? 3;\n                const progressText =\n                  numImages > 1\n                    ? `Generating image ${imageIndex + 1}/${numImages}... ${partialNum}/${totalPartials}`\n                    : `Generating... ${partialNum}/${totalPartials}`;\n\n                const partialAttachment: MessageAttachment = {\n                  type: \"generated-image\",\n                  name: `generated-image.${format}`,\n                  preview: `data:${mimeType};base64,${imageData}`,\n                  mimeType,\n                };\n\n                this.updateConversationMessage(\n                  targetConversationId,\n                  assistantMessage.id,\n                  (msg) => {\n                    msg.content = progressText;\n                    if (imageIndex === 0) {\n                      // First image - safe to replace attachments with partial preview\n                      msg.attachments = [partialAttachment];\n                    } else {\n                      // Subsequent images - keep existing finals, show partial at current position\n                      const existingAttachments = msg.attachments || [];\n                      // Keep only the completed final images (up to current imageIndex)\n                      const finals = existingAttachments.slice(0, imageIndex);\n                      msg.attachments = [...finals, partialAttachment];\n                    }\n                  },\n                );\n              } else if (parsed.type === \"final\") {\n                // Final image - replace partial at this position\n                const newAttachment: MessageAttachment = {\n                  type: \"generated-image\",\n                  name: `generated-image-${imageIndex + 1}.${format}`,\n                  preview: `data:${mimeType};base64,${imageData}`,\n                  mimeType,\n                };\n\n                this.updateConversationMessage(\n                  targetConversationId,\n                  assistantMessage.id,\n                  (msg) => {\n                    if (imageIndex === 0) {\n                      // First final image - replace any partial preview\n                      msg.attachments = [newAttachment];\n                    } else {\n                      // Subsequent images - keep previous finals, replace partial at current position\n                      const existingAttachments = msg.attachments || [];\n                      // Slice keeps indices 0 to imageIndex-1 (the previous final images)\n                      const previousFinals = existingAttachments.slice(\n                        0,\n                        imageIndex,\n                      );\n                      msg.attachments = 
[...previousFinals, newAttachment];\n                    }\n\n                    // Update progress message for multiple images\n                    if (numImages > 1 && imageIndex < numImages - 1) {\n                      msg.content = `Generating image ${imageIndex + 2}/${numImages}...`;\n                    } else {\n                      msg.content = \"\";\n                    }\n                  },\n                );\n              }\n\n              this.syncActiveMessagesIfNeeded(targetConversationId);\n            }\n          },\n        );\n      }\n    } catch (error) {\n      console.error(\"Error generating image:\", error);\n      this.handleStreamingError(\n        error,\n        targetConversationId,\n        assistantMessage.id,\n        \"Failed to generate image\",\n      );\n    } finally {\n      this.isLoading = false;\n      this.saveConversationsToStorage();\n    }\n  }\n\n  /**\n   * Edit an image using the image edit API\n   */\n  async editImage(\n    prompt: string,\n    imageDataUrl: string,\n    modelId?: string,\n  ): Promise<void> {\n    if (!prompt.trim() || !imageDataUrl || this.isLoading) return;\n\n    if (!this.hasStartedChat) {\n      this.startChat();\n    }\n\n    // Capture the target conversation ID at the start of the request\n    const targetConversationId = this.activeConversationId;\n    if (!targetConversationId) return;\n\n    this.isLoading = true;\n    this.currentResponse = \"\";\n\n    // Add user message directly to the target conversation\n    const userMessage: Message = {\n      id: generateUUID(),\n      role: \"user\",\n      content: prompt,\n      timestamp: Date.now(),\n      requestType: \"image-editing\",\n      sourceImageDataUrl: imageDataUrl,\n    };\n\n    const targetConversation = this.conversations.find(\n      (c) => c.id === targetConversationId,\n    );\n    if (!targetConversation) {\n      this.isLoading = false;\n      return;\n    }\n    targetConversation.messages.push(userMessage);\n\n    // Create placeholder for assistant message directly in target conversation\n    const assistantMessage = this.addMessageToConversation(\n      targetConversationId,\n      \"assistant\",\n      \"Editing image...\",\n    );\n    if (!assistantMessage) {\n      this.isLoading = false;\n      return;\n    }\n\n    // Sync to this.messages if viewing the target conversation\n    this.syncActiveMessagesIfNeeded(targetConversationId);\n    this.saveConversationsToStorage();\n\n    // Clear editing state\n    this.editingImage = null;\n\n    try {\n      // Determine the model to use\n      const model = this.getModelForRequest(modelId);\n      if (!model) {\n        throw new Error(\n          \"No model selected. Please select an image generation model.\",\n        );\n      }\n\n      // Convert base64 data URL to blob\n      const response = await fetch(imageDataUrl);\n      const imageBlob = await response.blob();\n\n      // Build FormData request\n      const formData = new FormData();\n      formData.append(\"model\", model);\n      formData.append(\"prompt\", prompt);\n      formData.append(\"image\", imageBlob, \"image.png\");\n\n      // Add params from image generation params\n      const params = this.imageGenerationParams;\n      formData.append(\"quality\", params.quality);\n      formData.append(\"size\", params.size);\n      formData.append(\"output_format\", params.outputFormat);\n      formData.append(\"response_format\", \"b64_json\");\n      formData.append(\"stream\", params.stream ? 
\"1\" : \"0\");\n      formData.append(\"partial_images\", params.partialImages.toString());\n      formData.append(\"input_fidelity\", params.inputFidelity);\n\n      // Advanced params\n      const hasAdvancedParams =\n        params.seed !== null ||\n        params.numInferenceSteps !== null ||\n        params.guidance !== null ||\n        (params.negativePrompt !== null &&\n          params.negativePrompt.trim() !== \"\") ||\n        params.numSyncSteps !== null;\n\n      if (hasAdvancedParams) {\n        formData.append(\n          \"advanced_params\",\n          JSON.stringify({\n            ...(params.seed !== null && { seed: params.seed }),\n            ...(params.numInferenceSteps !== null && {\n              num_inference_steps: params.numInferenceSteps,\n            }),\n            ...(params.guidance !== null && { guidance: params.guidance }),\n            ...(params.negativePrompt !== null &&\n              params.negativePrompt.trim() !== \"\" && {\n                negative_prompt: params.negativePrompt,\n              }),\n            ...(params.numSyncSteps !== null && {\n              num_sync_steps: params.numSyncSteps,\n            }),\n          }),\n        );\n      }\n\n      const apiResponse = await fetch(\"/v1/images/edits\", {\n        method: \"POST\",\n        body: formData,\n      });\n\n      if (!apiResponse.ok) {\n        const errorText = await apiResponse.text();\n        throw new Error(`API error: ${apiResponse.status} - ${errorText}`);\n      }\n\n      // Streaming requires both stream=true AND partialImages > 0\n      const isStreaming = params.stream && params.partialImages > 0;\n\n      if (!isStreaming) {\n        // Non-streaming: parse JSON response directly\n        const jsonResponse = (await apiResponse.json()) as ImageApiResponse;\n        const format = params.outputFormat || \"png\";\n        const mimeType = `image/${format}`;\n        const attachments: MessageAttachment[] = jsonResponse.data\n          .filter((img) => img.b64_json)\n          .map((img) => ({\n            type: \"generated-image\" as const,\n            name: `edited-image.${format}`,\n            preview: `data:${mimeType};base64,${img.b64_json}`,\n            mimeType,\n          }));\n\n        this.updateConversationMessage(\n          targetConversationId,\n          assistantMessage.id,\n          (msg) => {\n            msg.content = \"\";\n            msg.attachments = attachments;\n          },\n        );\n        this.syncActiveMessagesIfNeeded(targetConversationId);\n      } else {\n        // Streaming mode: use SSE parser\n        const reader = apiResponse.body?.getReader();\n        if (!reader) {\n          throw new Error(\"No response body\");\n        }\n\n        interface ImageEditChunk {\n          data?: { b64_json?: string };\n          format?: string;\n          type?: \"partial\" | \"final\";\n          partial_index?: number;\n          total_partials?: number;\n        }\n\n        await this.parseSSEStream<ImageEditChunk>(\n          reader,\n          targetConversationId,\n          (parsed) => {\n            const imageData = parsed.data?.b64_json;\n\n            if (imageData) {\n              const format = parsed.format || \"png\";\n              const mimeType = `image/${format}`;\n              if (parsed.type === \"partial\") {\n                // Update with partial image and progress\n                const partialNum = (parsed.partial_index ?? 0) + 1;\n                const totalPartials = parsed.total_partials ?? 
3;\n                this.updateConversationMessage(\n                  targetConversationId,\n                  assistantMessage.id,\n                  (msg) => {\n                    msg.content = `Editing... ${partialNum}/${totalPartials}`;\n                    msg.attachments = [\n                      {\n                        type: \"generated-image\",\n                        name: `edited-image.${format}`,\n                        preview: `data:${mimeType};base64,${imageData}`,\n                        mimeType,\n                      },\n                    ];\n                  },\n                );\n              } else if (parsed.type === \"final\") {\n                // Final image\n                this.updateConversationMessage(\n                  targetConversationId,\n                  assistantMessage.id,\n                  (msg) => {\n                    msg.content = \"\";\n                    msg.attachments = [\n                      {\n                        type: \"generated-image\",\n                        name: `edited-image.${format}`,\n                        preview: `data:${mimeType};base64,${imageData}`,\n                        mimeType,\n                      },\n                    ];\n                  },\n                );\n              }\n              this.syncActiveMessagesIfNeeded(targetConversationId);\n            }\n          },\n        );\n      }\n    } catch (error) {\n      console.error(\"Error editing image:\", error);\n      this.handleStreamingError(\n        error,\n        targetConversationId,\n        assistantMessage.id,\n        \"Failed to edit image\",\n      );\n    } finally {\n      this.isLoading = false;\n      this.saveConversationsToStorage();\n    }\n  }\n\n  /**\n   * Clear current chat and go back to welcome state\n   */\n  clearChat() {\n    this.activeConversationId = null;\n    this.messages = [];\n    this.hasStartedChat = false;\n    this.isTopologyMinimized = false;\n    this.currentResponse = \"\";\n    // Clear performance stats\n    this.ttftMs = null;\n    this.tps = null;\n  }\n\n  /**\n   * Get the active conversation\n   */\n  getActiveConversation(): Conversation | null {\n    if (!this.activeConversationId) return null;\n    return (\n      this.conversations.find((c) => c.id === this.activeConversationId) || null\n    );\n  }\n\n  /**\n   * Update the thinking preference for the active conversation\n   */\n  setConversationThinking(enabled: boolean) {\n    this.thinkingEnabled = enabled;\n    const conv = this.getActiveConversation();\n    if (conv) {\n      conv.enableThinking = enabled;\n      this.saveConversationsToStorage();\n    }\n  }\n\n  /**\n   * Start a download on a specific node\n   */\n  async startDownload(nodeId: string, shardMetadata: object): Promise<void> {\n    try {\n      const response = await fetch(\"/download/start\", {\n        method: \"POST\",\n        headers: { \"Content-Type\": \"application/json\" },\n        body: JSON.stringify({\n          targetNodeId: nodeId,\n          shardMetadata: shardMetadata,\n        }),\n      });\n      if (!response.ok) {\n        const errorText = await response.text();\n        throw new Error(\n          `Failed to start download: ${response.status} - ${errorText}`,\n        );\n      }\n    } catch (error) {\n      console.error(\"Error starting download:\", error);\n      throw error;\n    }\n  }\n\n  /**\n   * Delete a downloaded model from a specific node\n   */\n  async deleteDownload(nodeId: string, modelId: string): 
Promise<void> {\n    try {\n      const response = await fetch(\n        `/download/${encodeURIComponent(nodeId)}/${encodeURIComponent(modelId)}`,\n        {\n          method: \"DELETE\",\n        },\n      );\n      if (!response.ok) {\n        const errorText = await response.text();\n        throw new Error(\n          `Failed to delete download: ${response.status} - ${errorText}`,\n        );\n      }\n    } catch (error) {\n      console.error(\"Error deleting download:\", error);\n      throw error;\n    }\n  }\n\n  /**\n   * List all available traces\n   */\n  async listTraces(): Promise<TraceListResponse> {\n    const response = await fetch(\"/v1/traces\");\n    if (!response.ok) {\n      throw new Error(`Failed to list traces: ${response.status}`);\n    }\n    return (await response.json()) as TraceListResponse;\n  }\n\n  /**\n   * Check if a trace exists for a given task ID\n   */\n  async checkTraceExists(taskId: string): Promise<boolean> {\n    try {\n      const response = await fetch(`/v1/traces/${encodeURIComponent(taskId)}`);\n      return response.ok;\n    } catch {\n      return false;\n    }\n  }\n\n  /**\n   * Get computed statistics for a task's trace\n   */\n  async fetchTraceStats(taskId: string): Promise<TraceStatsResponse> {\n    const response = await fetch(\n      `/v1/traces/${encodeURIComponent(taskId)}/stats`,\n    );\n    if (!response.ok) {\n      throw new Error(`Failed to fetch trace stats: ${response.status}`);\n    }\n    return (await response.json()) as TraceStatsResponse;\n  }\n\n  /**\n   * Delete traces by task IDs\n   */\n  async deleteTraces(\n    taskIds: string[],\n  ): Promise<{ deleted: string[]; notFound: string[] }> {\n    const response = await fetch(\"/v1/traces/delete\", {\n      method: \"POST\",\n      headers: { \"Content-Type\": \"application/json\" },\n      body: JSON.stringify({ taskIds }),\n    });\n    if (!response.ok) {\n      throw new Error(`Failed to delete traces: ${response.status}`);\n    }\n    return await response.json();\n  }\n\n  /**\n   * Get the URL for the raw trace file (for Perfetto)\n   */\n  getTraceRawUrl(taskId: string): string {\n    return `/v1/traces/${encodeURIComponent(taskId)}/raw`;\n  }\n}\n\nexport const appStore = new AppStore();\n\n// Reactive exports\nexport const hasStartedChat = () => appStore.hasStartedChat;\nexport const messages = () => appStore.messages;\nexport const currentResponse = () => appStore.currentResponse;\nexport const isLoading = () => appStore.isLoading;\nexport const ttftMs = () => appStore.ttftMs;\nexport const tps = () => appStore.tps;\nexport const totalTokens = () => appStore.totalTokens;\nexport const prefillProgress = () => appStore.prefillProgress;\nexport const topologyData = () => appStore.topologyData;\nexport const instances = () => appStore.instances;\nexport const runners = () => appStore.runners;\nexport const downloads = () => appStore.downloads;\nexport const nodeDisk = () => appStore.nodeDisk;\nexport const placementPreviews = () => appStore.placementPreviews;\nexport const selectedPreviewModelId = () => appStore.selectedPreviewModelId;\nexport const isLoadingPreviews = () => appStore.isLoadingPreviews;\nexport const lastUpdate = () => appStore.lastUpdate;\nexport const isTopologyMinimized = () => appStore.isTopologyMinimized;\nexport const selectedChatModel = () => appStore.selectedChatModel;\nexport const thinkingEnabled = () => appStore.thinkingEnabled;\nexport const debugMode = () => appStore.getDebugMode();\nexport const topologyOnlyMode = () => 
appStore.getTopologyOnlyMode();\nexport const chatSidebarVisible = () => appStore.getChatSidebarVisible();\n\n// Actions\nexport const stopGeneration = () => appStore.stopGeneration();\nexport const startChat = () => appStore.startChat();\nexport const sendMessage = (\n  content: string,\n  files?: {\n    id: string;\n    name: string;\n    type: string;\n    textContent?: string;\n    preview?: string;\n  }[],\n  enableThinking?: boolean | null,\n) => appStore.sendMessage(content, files, enableThinking);\nexport const generateImage = (prompt: string, modelId?: string) =>\n  appStore.generateImage(prompt, modelId);\nexport const editImage = (\n  prompt: string,\n  imageDataUrl: string,\n  modelId?: string,\n) => appStore.editImage(prompt, imageDataUrl, modelId);\nexport const editingImage = () => appStore.editingImage;\nexport const setEditingImage = (imageDataUrl: string, sourceMessage: Message) =>\n  appStore.setEditingImage(imageDataUrl, sourceMessage);\nexport const clearEditingImage = () => appStore.clearEditingImage();\nexport const clearChat = () => appStore.clearChat();\nexport const setSelectedChatModel = (modelId: string) =>\n  appStore.setSelectedModel(modelId);\nexport const selectPreviewModel = (modelId: string | null) =>\n  appStore.selectPreviewModel(modelId);\nexport const togglePreviewNodeFilter = (nodeId: string) =>\n  appStore.togglePreviewNodeFilter(nodeId);\nexport const clearPreviewNodeFilter = () => appStore.clearPreviewNodeFilter();\nexport const previewNodeFilter = () => appStore.previewNodeFilter;\nexport const deleteMessage = (messageId: string) =>\n  appStore.deleteMessage(messageId);\nexport const editMessage = (messageId: string, newContent: string) =>\n  appStore.editMessage(messageId, newContent);\nexport const editAndRegenerate = (messageId: string, newContent: string) =>\n  appStore.editAndRegenerate(messageId, newContent);\nexport const regenerateLastResponse = () => appStore.regenerateLastResponse();\nexport const regenerateFromToken = (messageId: string, tokenIndex: number) =>\n  appStore.regenerateFromToken(messageId, tokenIndex);\n\n// Conversation actions\nexport const conversations = () => appStore.conversations;\nexport const activeConversationId = () => appStore.activeConversationId;\nexport const createConversation = (name?: string) =>\n  appStore.createConversation(name);\nexport const loadConversation = (id: string) => appStore.loadConversation(id);\nexport const deleteConversation = (id: string) =>\n  appStore.deleteConversation(id);\nexport const deleteAllConversations = () => appStore.deleteAllConversations();\nexport const renameConversation = (id: string, name: string) =>\n  appStore.renameConversation(id, name);\nexport const getActiveConversation = () => appStore.getActiveConversation();\nexport const setConversationThinking = (enabled: boolean) =>\n  appStore.setConversationThinking(enabled);\n\n// Sidebar actions\nexport const isSidebarOpen = () => appStore.isSidebarOpen;\nexport const toggleSidebar = () => appStore.toggleSidebar();\nexport const toggleDebugMode = () => appStore.toggleDebugMode();\nexport const setDebugMode = (enabled: boolean) =>\n  appStore.setDebugMode(enabled);\nexport const toggleTopologyOnlyMode = () => appStore.toggleTopologyOnlyMode();\nexport const setTopologyOnlyMode = (enabled: boolean) =>\n  appStore.setTopologyOnlyMode(enabled);\nexport const toggleChatSidebarVisible = () =>\n  appStore.toggleChatSidebarVisible();\nexport const setChatSidebarVisible = (visible: boolean) =>\n  
appStore.setChatSidebarVisible(visible);\n\n// Mobile sidebar state\nexport const mobileChatSidebarOpen = () => appStore.mobileChatSidebarOpen;\nexport const toggleMobileChatSidebar = () => appStore.toggleMobileChatSidebar();\nexport const setMobileChatSidebarOpen = (open: boolean) =>\n  appStore.setMobileChatSidebarOpen(open);\nexport const mobileRightSidebarOpen = () => appStore.mobileRightSidebarOpen;\nexport const toggleMobileRightSidebar = () =>\n  appStore.toggleMobileRightSidebar();\nexport const setMobileRightSidebarOpen = (open: boolean) =>\n  appStore.setMobileRightSidebarOpen(open);\n\nexport const refreshState = () => appStore.fetchState();\n\n// Connection status\nexport const isConnected = () => appStore.isConnected;\n\n// Node identities (for OS version mismatch detection)\nexport const nodeIdentities = () => appStore.nodeIdentities;\n\n// Thunderbolt & RDMA status\nexport const nodeThunderbolt = () => appStore.nodeThunderbolt;\nexport const nodeRdmaCtl = () => appStore.nodeRdmaCtl;\nexport const thunderboltBridgeCycles = () => appStore.thunderboltBridgeCycles;\nexport const nodeThunderboltBridge = () => appStore.nodeThunderboltBridge;\n\n// Image generation params\nexport const imageGenerationParams = () => appStore.getImageGenerationParams();\nexport const setImageGenerationParams = (\n  params: Partial<ImageGenerationParams>,\n) => appStore.setImageGenerationParams(params);\nexport const resetImageGenerationParams = () =>\n  appStore.resetImageGenerationParams();\n\n// Download actions\nexport const startDownload = (nodeId: string, shardMetadata: object) =>\n  appStore.startDownload(nodeId, shardMetadata);\nexport const deleteDownload = (nodeId: string, modelId: string) =>\n  appStore.deleteDownload(nodeId, modelId);\n\n// Trace actions\nexport const listTraces = () => appStore.listTraces();\nexport const checkTraceExists = (taskId: string) =>\n  appStore.checkTraceExists(taskId);\nexport const fetchTraceStats = (taskId: string) =>\n  appStore.fetchTraceStats(taskId);\nexport const getTraceRawUrl = (taskId: string) =>\n  appStore.getTraceRawUrl(taskId);\nexport const deleteTraces = (taskIds: string[]) =>\n  appStore.deleteTraces(taskIds);\n"
  },
  {
    "path": "dashboard/src/lib/stores/favorites.svelte.ts",
    "content": "/**\n * FavoritesStore - Manages favorite models with localStorage persistence\n */\n\nimport { browser } from \"$app/environment\";\n\nconst FAVORITES_KEY = \"exo-favorite-models\";\n\nclass FavoritesStore {\n  favorites = $state<Set<string>>(new Set());\n\n  constructor() {\n    if (browser) {\n      this.loadFromStorage();\n    }\n  }\n\n  private loadFromStorage() {\n    try {\n      const stored = localStorage.getItem(FAVORITES_KEY);\n      if (stored) {\n        const parsed = JSON.parse(stored) as string[];\n        this.favorites = new Set(parsed);\n      }\n    } catch (error) {\n      console.error(\"Failed to load favorites:\", error);\n    }\n  }\n\n  private saveToStorage() {\n    try {\n      const array = Array.from(this.favorites);\n      localStorage.setItem(FAVORITES_KEY, JSON.stringify(array));\n    } catch (error) {\n      console.error(\"Failed to save favorites:\", error);\n    }\n  }\n\n  add(baseModelId: string) {\n    const next = new Set(this.favorites);\n    next.add(baseModelId);\n    this.favorites = next;\n    this.saveToStorage();\n  }\n\n  remove(baseModelId: string) {\n    const next = new Set(this.favorites);\n    next.delete(baseModelId);\n    this.favorites = next;\n    this.saveToStorage();\n  }\n\n  toggle(baseModelId: string) {\n    if (this.favorites.has(baseModelId)) {\n      this.remove(baseModelId);\n    } else {\n      this.add(baseModelId);\n    }\n  }\n\n  isFavorite(baseModelId: string): boolean {\n    return this.favorites.has(baseModelId);\n  }\n\n  getAll(): string[] {\n    return Array.from(this.favorites);\n  }\n\n  getSet(): Set<string> {\n    return new Set(this.favorites);\n  }\n\n  hasAny(): boolean {\n    return this.favorites.size > 0;\n  }\n\n  clearAll() {\n    this.favorites = new Set();\n    this.saveToStorage();\n  }\n}\n\nexport const favoritesStore = new FavoritesStore();\n\nexport const favorites = () => favoritesStore.favorites;\nexport const hasFavorites = () => favoritesStore.hasAny();\nexport const isFavorite = (baseModelId: string) =>\n  favoritesStore.isFavorite(baseModelId);\nexport const toggleFavorite = (baseModelId: string) =>\n  favoritesStore.toggle(baseModelId);\nexport const addFavorite = (baseModelId: string) =>\n  favoritesStore.add(baseModelId);\nexport const removeFavorite = (baseModelId: string) =>\n  favoritesStore.remove(baseModelId);\nexport const getFavorites = () => favoritesStore.getAll();\nexport const getFavoritesSet = () => favoritesStore.getSet();\nexport const clearFavorites = () => favoritesStore.clearAll();\n"
  },
  {
    "path": "dashboard/src/lib/stores/recents.svelte.ts",
    "content": "/**\n * RecentsStore - Manages recently launched models with localStorage persistence\n */\n\nimport { browser } from \"$app/environment\";\n\nconst RECENTS_KEY = \"exo-recent-models\";\nconst MAX_RECENT_MODELS = 20;\n\ninterface RecentEntry {\n  modelId: string;\n  launchedAt: number;\n}\n\nclass RecentsStore {\n  recents = $state<RecentEntry[]>([]);\n\n  constructor() {\n    if (browser) {\n      this.loadFromStorage();\n    }\n  }\n\n  private loadFromStorage() {\n    try {\n      const stored = localStorage.getItem(RECENTS_KEY);\n      if (stored) {\n        const parsed = JSON.parse(stored) as RecentEntry[];\n        this.recents = parsed;\n      }\n    } catch (error) {\n      console.error(\"Failed to load recent models:\", error);\n    }\n  }\n\n  private saveToStorage() {\n    try {\n      localStorage.setItem(RECENTS_KEY, JSON.stringify(this.recents));\n    } catch (error) {\n      console.error(\"Failed to save recent models:\", error);\n    }\n  }\n\n  recordLaunch(modelId: string) {\n    // Remove existing entry for this model (if any) to move it to top\n    const filtered = this.recents.filter((r) => r.modelId !== modelId);\n    // Prepend new entry\n    const next = [{ modelId, launchedAt: Date.now() }, ...filtered];\n    // Cap at max\n    this.recents = next.slice(0, MAX_RECENT_MODELS);\n    this.saveToStorage();\n  }\n\n  getRecentModelIds(): string[] {\n    return this.recents.map((r) => r.modelId);\n  }\n\n  hasAny(): boolean {\n    return this.recents.length > 0;\n  }\n\n  clearAll() {\n    this.recents = [];\n    this.saveToStorage();\n  }\n}\n\nexport const recentsStore = new RecentsStore();\n\nexport const hasRecents = () => recentsStore.hasAny();\nexport const getRecentModelIds = () => recentsStore.getRecentModelIds();\nexport const getRecentEntries = () => recentsStore.recents;\nexport const recordRecentLaunch = (modelId: string) =>\n  recentsStore.recordLaunch(modelId);\nexport const clearRecents = () => recentsStore.clearAll();\n"
  },
  {
    "path": "dashboard/src/lib/stores/toast.svelte.ts",
    "content": "/**\n * Toast notification store - Global notification system for the EXO dashboard.\n *\n * Usage:\n *   import { addToast, dismissToast, toasts } from \"$lib/stores/toast.svelte\";\n *   addToast({ type: \"success\", message: \"Model launched\" });\n *   addToast({ type: \"error\", message: \"Connection lost\", persistent: true });\n */\n\ntype ToastType = \"success\" | \"error\" | \"warning\" | \"info\";\n\nexport interface Toast {\n  id: string;\n  type: ToastType;\n  message: string;\n  /** Auto-dismiss after this many ms. 0 = persistent (must be dismissed manually). */\n  duration: number;\n  createdAt: number;\n}\n\ninterface ToastInput {\n  type: ToastType;\n  message: string;\n  /** If true, toast stays until manually dismissed. Default: false. */\n  persistent?: boolean;\n  /** Auto-dismiss duration in ms. Default: 4000 for success/info, 6000 for error/warning. */\n  duration?: number;\n}\n\nconst DEFAULT_DURATIONS: Record<ToastType, number> = {\n  success: 4000,\n  info: 4000,\n  warning: 6000,\n  error: 6000,\n};\n\nlet toastList = $state<Toast[]>([]);\nconst timers = new Map<string, ReturnType<typeof setTimeout>>();\n\nfunction generateId(): string {\n  return `toast-${Date.now()}-${Math.random().toString(36).slice(2, 8)}`;\n}\n\nexport function addToast(input: ToastInput): string {\n  const id = generateId();\n  const duration = input.persistent\n    ? 0\n    : (input.duration ?? DEFAULT_DURATIONS[input.type]);\n\n  const toast: Toast = {\n    id,\n    type: input.type,\n    message: input.message,\n    duration,\n    createdAt: Date.now(),\n  };\n\n  toastList = [...toastList, toast];\n\n  if (duration > 0) {\n    const timer = setTimeout(() => dismissToast(id), duration);\n    timers.set(id, timer);\n  }\n\n  return id;\n}\n\nexport function dismissToast(id: string): void {\n  const timer = timers.get(id);\n  if (timer) {\n    clearTimeout(timer);\n    timers.delete(id);\n  }\n  toastList = toastList.filter((t) => t.id !== id);\n}\n\n/** Dismiss all toasts matching a message (useful for dedup). */\nexport function dismissByMessage(message: string): void {\n  const matching = toastList.filter((t) => t.message === message);\n  for (const t of matching) {\n    dismissToast(t.id);\n  }\n}\n\nexport function toasts(): Toast[] {\n  return toastList;\n}\n"
  },
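The duration rules in `addToast` are easy to misread, so here is the resolution order as a standalone sketch (`resolveDuration` is a hypothetical helper; the store inlines this expression): `persistent` always wins, an explicit `duration` comes next, and otherwise the per-type default applies.

```ts
// Hypothetical standalone version of addToast's auto-dismiss resolution.
// A result of 0 means "never auto-dismiss"; the store only schedules a
// setTimeout when the resolved duration is > 0.
type ToastType = "success" | "error" | "warning" | "info";

const DEFAULT_DURATIONS: Record<ToastType, number> = {
  success: 4000,
  info: 4000,
  warning: 6000,
  error: 6000,
};

function resolveDuration(input: {
  type: ToastType;
  persistent?: boolean;
  duration?: number;
}): number {
  return input.persistent
    ? 0
    : (input.duration ?? DEFAULT_DURATIONS[input.type]);
}

console.log(resolveDuration({ type: "error" })); // 6000
console.log(resolveDuration({ type: "error", persistent: true })); // 0
console.log(resolveDuration({ type: "info", duration: 1500 })); // 1500
```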
  {
    "path": "dashboard/src/lib/types/files.ts",
    "content": "/**\n * File attachment types for the chat interface\n */\n\nexport interface ChatUploadedFile {\n  id: string;\n  name: string;\n  size: number;\n  type: string;\n  file: File;\n  preview?: string;\n  textContent?: string;\n}\n\nexport interface ChatAttachment {\n  type: \"image\" | \"text\" | \"pdf\" | \"audio\";\n  name: string;\n  content?: string;\n  base64Url?: string;\n  mimeType?: string;\n}\n\nexport type FileCategory = \"image\" | \"text\" | \"pdf\" | \"audio\" | \"unknown\";\n\nexport const IMAGE_EXTENSIONS = [\n  \".jpg\",\n  \".jpeg\",\n  \".png\",\n  \".gif\",\n  \".webp\",\n  \".svg\",\n];\nexport const IMAGE_MIME_TYPES = [\n  \"image/jpeg\",\n  \"image/png\",\n  \"image/gif\",\n  \"image/webp\",\n  \"image/svg+xml\",\n];\n\nexport const TEXT_EXTENSIONS = [\n  \".txt\",\n  \".md\",\n  \".json\",\n  \".xml\",\n  \".yaml\",\n  \".yml\",\n  \".csv\",\n  \".log\",\n  \".js\",\n  \".ts\",\n  \".jsx\",\n  \".tsx\",\n  \".py\",\n  \".java\",\n  \".cpp\",\n  \".c\",\n  \".h\",\n  \".css\",\n  \".html\",\n  \".htm\",\n  \".sql\",\n  \".sh\",\n  \".bat\",\n  \".rs\",\n  \".go\",\n  \".rb\",\n  \".php\",\n  \".swift\",\n  \".kt\",\n  \".scala\",\n  \".r\",\n  \".dart\",\n  \".vue\",\n  \".svelte\",\n];\nexport const TEXT_MIME_TYPES = [\n  \"text/plain\",\n  \"text/markdown\",\n  \"text/csv\",\n  \"text/html\",\n  \"text/css\",\n  \"application/json\",\n  \"application/xml\",\n  \"text/xml\",\n  \"application/javascript\",\n  \"text/javascript\",\n  \"application/typescript\",\n];\n\nexport const PDF_EXTENSIONS = [\".pdf\"];\nexport const PDF_MIME_TYPES = [\"application/pdf\"];\n\nexport const AUDIO_EXTENSIONS = [\".mp3\", \".wav\", \".ogg\", \".m4a\"];\nexport const AUDIO_MIME_TYPES = [\n  \"audio/mpeg\",\n  \"audio/wav\",\n  \"audio/ogg\",\n  \"audio/mp4\",\n];\n\n/**\n * Get file category based on MIME type and extension\n */\nexport function getFileCategory(\n  mimeType: string,\n  fileName: string,\n): FileCategory {\n  const extension = fileName.toLowerCase().slice(fileName.lastIndexOf(\".\"));\n\n  if (\n    IMAGE_MIME_TYPES.includes(mimeType) ||\n    IMAGE_EXTENSIONS.includes(extension)\n  ) {\n    return \"image\";\n  }\n  if (PDF_MIME_TYPES.includes(mimeType) || PDF_EXTENSIONS.includes(extension)) {\n    return \"pdf\";\n  }\n  if (\n    AUDIO_MIME_TYPES.includes(mimeType) ||\n    AUDIO_EXTENSIONS.includes(extension)\n  ) {\n    return \"audio\";\n  }\n  if (\n    TEXT_MIME_TYPES.includes(mimeType) ||\n    TEXT_EXTENSIONS.includes(extension) ||\n    mimeType.startsWith(\"text/\")\n  ) {\n    return \"text\";\n  }\n  return \"unknown\";\n}\n\n/**\n * Get accept string for file input based on categories\n */\nexport function getAcceptString(categories: FileCategory[]): string {\n  const accepts: string[] = [];\n\n  for (const category of categories) {\n    switch (category) {\n      case \"image\":\n        accepts.push(...IMAGE_EXTENSIONS, ...IMAGE_MIME_TYPES);\n        break;\n      case \"text\":\n        accepts.push(...TEXT_EXTENSIONS, ...TEXT_MIME_TYPES);\n        break;\n      case \"pdf\":\n        accepts.push(...PDF_EXTENSIONS, ...PDF_MIME_TYPES);\n        break;\n      case \"audio\":\n        accepts.push(...AUDIO_EXTENSIONS, ...AUDIO_MIME_TYPES);\n        break;\n    }\n  }\n\n  return accepts.join(\",\");\n}\n\n/**\n * Format file size for display\n */\nexport function formatFileSize(bytes: number): string {\n  if (bytes === 0) return \"0 B\";\n  const k = 1024;\n  const sizes = [\"B\", \"KB\", \"MB\", \"GB\"];\n  const i = 
Math.floor(Math.log(bytes) / Math.log(k));\n  return parseFloat((bytes / Math.pow(k, i)).toFixed(1)) + \" \" + sizes[i];\n}\n\n/**\n * Read file as data URL (base64)\n */\nexport function readFileAsDataURL(file: File): Promise<string> {\n  return new Promise((resolve, reject) => {\n    const reader = new FileReader();\n    reader.onload = () => resolve(reader.result as string);\n    reader.onerror = () => reject(reader.error);\n    reader.readAsDataURL(file);\n  });\n}\n\n/**\n * Read file as text\n */\nexport function readFileAsText(file: File): Promise<string> {\n  return new Promise((resolve, reject) => {\n    const reader = new FileReader();\n    reader.onload = () => resolve(reader.result as string);\n    reader.onerror = () => reject(reader.error);\n    reader.readAsText(file);\n  });\n}\n\n/**\n * Process uploaded files into ChatUploadedFile format\n */\nexport async function processUploadedFiles(\n  files: File[],\n): Promise<ChatUploadedFile[]> {\n  const results: ChatUploadedFile[] = [];\n\n  for (const file of files) {\n    const id =\n      Date.now().toString() + Math.random().toString(36).substring(2, 9);\n    const category = getFileCategory(file.type, file.name);\n\n    const base: ChatUploadedFile = {\n      id,\n      name: file.name,\n      size: file.size,\n      type: file.type,\n      file,\n    };\n\n    try {\n      if (category === \"image\") {\n        const preview = await readFileAsDataURL(file);\n        results.push({ ...base, preview });\n      } else if (category === \"text\" || category === \"unknown\") {\n        const textContent = await readFileAsText(file);\n        results.push({ ...base, textContent });\n      } else if (category === \"pdf\") {\n        results.push(base);\n      } else if (category === \"audio\") {\n        const preview = await readFileAsDataURL(file);\n        results.push({ ...base, preview });\n      } else {\n        results.push(base);\n      }\n    } catch (error) {\n      console.error(\"Error processing file:\", file.name, error);\n      results.push(base);\n    }\n  }\n\n  return results;\n}\n"
  },
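A short usage sketch for the helpers above (the `$lib/types/files` import path assumes SvelteKit's alias; the literal values are illustrative). Note that `getFileCategory` checks categories in a fixed precedence — image, then pdf, then audio, then text — with a bare `text/*` MIME prefix as the final text fallback.

```ts
import {
  getFileCategory,
  getAcceptString,
  formatFileSize,
} from "$lib/types/files";

// Extension matches rescue files whose MIME type the browser reports generically.
console.log(getFileCategory("image/png", "photo.png")); // "image"
console.log(getFileCategory("application/octet-stream", "notes.md")); // "text"
console.log(getFileCategory("application/octet-stream", "blob.bin")); // "unknown"

// Accept string for an <input type="file"> restricted to images and PDFs.
console.log(getAcceptString(["image", "pdf"]));

// 1536 bytes falls into the KB bucket: 1536 / 1024 = 1.5.
console.log(formatFileSize(0)); // "0 B"
console.log(formatFileSize(1536)); // "1.5 KB"
```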
  {
    "path": "dashboard/src/lib/utils/downloads.ts",
    "content": "/**\n * Shared utilities for parsing and querying download state.\n *\n * The download state from `/state` is shaped as:\n *   Record<NodeId, Array<TaggedDownloadEntry>>\n *\n * Each entry is a tagged union object like:\n *   { \"DownloadCompleted\": { shard_metadata: { \"PipelineShardMetadata\": { model_card: { model_id: \"...\" }, ... } }, ... } }\n */\n\n/** Unwrap one level of tagged-union envelope, returning [tag, payload]. */\nfunction unwrapTagged(\n  obj: Record<string, unknown>,\n): [string, Record<string, unknown>] | null {\n  const keys = Object.keys(obj);\n  if (keys.length !== 1) return null;\n  const tag = keys[0];\n  const payload = obj[tag];\n  if (!payload || typeof payload !== \"object\") return null;\n  return [tag, payload as Record<string, unknown>];\n}\n\n/** Extract the model ID string from a download entry's nested shard_metadata. */\nexport function extractModelIdFromDownload(\n  downloadPayload: Record<string, unknown>,\n): string | null {\n  const shardMetadata =\n    downloadPayload.shard_metadata ?? downloadPayload.shardMetadata;\n  if (!shardMetadata || typeof shardMetadata !== \"object\") return null;\n\n  const unwrapped = unwrapTagged(shardMetadata as Record<string, unknown>);\n  if (!unwrapped) return null;\n  const [, shardData] = unwrapped;\n\n  const modelMeta = shardData.model_card ?? shardData.modelCard;\n  if (!modelMeta || typeof modelMeta !== \"object\") return null;\n\n  const meta = modelMeta as Record<string, unknown>;\n  return (meta.model_id as string) ?? (meta.modelId as string) ?? null;\n}\n\n/** Extract the shard_metadata object from a download entry payload. */\nexport function extractShardMetadata(\n  downloadPayload: Record<string, unknown>,\n): Record<string, unknown> | null {\n  const shardMetadata =\n    downloadPayload.shard_metadata ?? downloadPayload.shardMetadata;\n  if (!shardMetadata || typeof shardMetadata !== \"object\") return null;\n  return shardMetadata as Record<string, unknown>;\n}\n\n/** Get the download tag (DownloadCompleted, DownloadOngoing, etc.) from a wrapped entry. */\nexport function getDownloadTag(\n  entry: unknown,\n): [string, Record<string, unknown>] | null {\n  if (!entry || typeof entry !== \"object\") return null;\n  return unwrapTagged(entry as Record<string, unknown>);\n}\n\n/**\n * Iterate over all download entries for a given node, yielding [tag, payload, modelId].\n */\nfunction* iterNodeDownloads(\n  nodeDownloads: unknown[],\n): Generator<[string, Record<string, unknown>, string]> {\n  for (const entry of nodeDownloads) {\n    const tagged = getDownloadTag(entry);\n    if (!tagged) continue;\n    const [tag, payload] = tagged;\n    const modelId = extractModelIdFromDownload(payload);\n    if (!modelId) continue;\n    yield [tag, payload, modelId];\n  }\n}\n\n/** Check if a specific model is fully downloaded (DownloadCompleted) on a specific node. */\nexport function isModelDownloadedOnNode(\n  downloadsData: Record<string, unknown[]>,\n  nodeId: string,\n  modelId: string,\n): boolean {\n  const nodeDownloads = downloadsData[nodeId];\n  if (!Array.isArray(nodeDownloads)) return false;\n\n  for (const [tag, , entryModelId] of iterNodeDownloads(nodeDownloads)) {\n    if (tag === \"DownloadCompleted\" && entryModelId === modelId) return true;\n  }\n  return false;\n}\n\n/** Get all node IDs where a model is fully downloaded (DownloadCompleted). 
*/\nexport function getNodesWithModelDownloaded(\n  downloadsData: Record<string, unknown[]>,\n  modelId: string,\n): string[] {\n  const result: string[] = [];\n  for (const nodeId of Object.keys(downloadsData)) {\n    if (isModelDownloadedOnNode(downloadsData, nodeId, modelId)) {\n      result.push(nodeId);\n    }\n  }\n  return result;\n}\n\n/**\n * Find shard metadata for a model from any download entry across all nodes.\n * Returns the first match found (completed entries are preferred).\n */\nexport function getShardMetadataForModel(\n  downloadsData: Record<string, unknown[]>,\n  modelId: string,\n): Record<string, unknown> | null {\n  let fallback: Record<string, unknown> | null = null;\n\n  for (const nodeDownloads of Object.values(downloadsData)) {\n    if (!Array.isArray(nodeDownloads)) continue;\n\n    for (const [tag, payload, entryModelId] of iterNodeDownloads(\n      nodeDownloads,\n    )) {\n      if (entryModelId !== modelId) continue;\n      const shard = extractShardMetadata(payload);\n      if (!shard) continue;\n\n      if (tag === \"DownloadCompleted\") return shard;\n      if (!fallback) fallback = shard;\n    }\n  }\n  return fallback;\n}\n\n/**\n * Get the download status tag for a specific model on a specific node.\n * Returns the \"best\" status: DownloadCompleted > DownloadOngoing > others.\n */\nexport function getModelDownloadStatus(\n  downloadsData: Record<string, unknown[]>,\n  nodeId: string,\n  modelId: string,\n): string | null {\n  const nodeDownloads = downloadsData[nodeId];\n  if (!Array.isArray(nodeDownloads)) return null;\n\n  let best: string | null = null;\n  for (const [tag, , entryModelId] of iterNodeDownloads(nodeDownloads)) {\n    if (entryModelId !== modelId) continue;\n    if (tag === \"DownloadCompleted\") return tag;\n    if (tag === \"DownloadOngoing\") best = tag;\n    else if (!best) best = tag;\n  }\n  return best;\n}\n"
  },
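To make the nested tagged-union shape concrete, here is a hand-built `/state`-style payload run through the query helpers above (the node IDs, model ID, and import path are illustrative):

```ts
import {
  isModelDownloadedOnNode,
  getNodesWithModelDownloaded,
  getModelDownloadStatus,
} from "$lib/utils/downloads";

// Each download entry is a single-key tagged-union envelope; shard_metadata
// nests a second envelope around model_card, which carries model_id.
const downloadsData: Record<string, unknown[]> = {
  "node-a": [
    {
      DownloadCompleted: {
        shard_metadata: {
          PipelineShardMetadata: {
            model_card: { model_id: "llama-3.2-3b" },
          },
        },
      },
    },
  ],
  "node-b": [
    {
      DownloadOngoing: {
        shard_metadata: {
          PipelineShardMetadata: {
            model_card: { model_id: "llama-3.2-3b" },
          },
        },
      },
    },
  ],
};

console.log(isModelDownloadedOnNode(downloadsData, "node-a", "llama-3.2-3b")); // true
console.log(getNodesWithModelDownloaded(downloadsData, "llama-3.2-3b")); // ["node-a"]
console.log(getModelDownloadStatus(downloadsData, "node-b", "llama-3.2-3b")); // "DownloadOngoing"
```

Because `unwrapTagged` requires exactly one key per envelope, malformed entries are skipped rather than throwing, which keeps these queries safe against partially-populated `/state` responses.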
  {
    "path": "dashboard/src/routes/+layout.svelte",
    "content": "<script lang=\"ts\">\n  import \"../app.css\";\n  import ToastContainer from \"$lib/components/ToastContainer.svelte\";\n  import ConnectionBanner from \"$lib/components/ConnectionBanner.svelte\";\n\n  let { children } = $props();\n</script>\n\n<svelte:head>\n  <title>EXO</title>\n  <meta name=\"description\" content=\"EXO - Distributed AI Cluster Dashboard\" />\n</svelte:head>\n\n<div class=\"min-h-screen bg-background text-foreground\">\n  <ConnectionBanner />\n  {@render children?.()}\n  <ToastContainer />\n</div>\n"
  },
  {
    "path": "dashboard/src/routes/+page.svelte",
    "content": "<script lang=\"ts\">\n  import {\n    TopologyGraph,\n    ChatForm,\n    ChatMessages,\n    ChatSidebar,\n    ModelCard,\n    ModelPickerModal,\n    ChatModelSelector,\n  } from \"$lib/components\";\n  import {\n    pickAutoModel,\n    getAutoTierIndex,\n  } from \"$lib/components/ChatModelSelector.svelte\";\n  import {\n    favorites,\n    toggleFavorite,\n    getFavoritesSet,\n  } from \"$lib/stores/favorites.svelte\";\n  import {\n    hasRecents,\n    getRecentModelIds,\n    recordRecentLaunch,\n  } from \"$lib/stores/recents.svelte\";\n  import {\n    hasStartedChat,\n    isTopologyMinimized,\n    topologyData,\n    lastUpdate,\n    clearChat,\n    instances,\n    runners,\n    downloads,\n    placementPreviews,\n    selectedPreviewModelId,\n    isLoadingPreviews,\n    selectPreviewModel,\n    togglePreviewNodeFilter,\n    clearPreviewNodeFilter,\n    previewNodeFilter,\n    createConversation,\n    setSelectedChatModel,\n    selectedChatModel,\n    sendMessage,\n    generateImage,\n    editImage,\n    editingImage,\n    messages,\n    debugMode,\n    toggleDebugMode,\n    topologyOnlyMode,\n    toggleTopologyOnlyMode,\n    chatSidebarVisible,\n    toggleChatSidebarVisible,\n    mobileChatSidebarOpen,\n    toggleMobileChatSidebar,\n    setMobileChatSidebarOpen,\n    mobileRightSidebarOpen,\n    toggleMobileRightSidebar,\n    setMobileRightSidebarOpen,\n    nodeThunderbolt,\n    nodeRdmaCtl,\n    thunderboltBridgeCycles,\n    nodeThunderboltBridge,\n    nodeIdentities,\n    isConnected,\n    type DownloadProgress,\n    type PlacementPreview,\n  } from \"$lib/stores/app.svelte\";\n  import { addToast, dismissByMessage } from \"$lib/stores/toast.svelte\";\n  import HeaderNav from \"$lib/components/HeaderNav.svelte\";\n  import DeviceIcon from \"$lib/components/DeviceIcon.svelte\";\n  import { fade, fly, slide } from \"svelte/transition\";\n  import { tweened } from \"svelte/motion\";\n  import { cubicInOut, cubicOut } from \"svelte/easing\";\n  import { onMount } from \"svelte\";\n\n  const chatStarted = $derived(hasStartedChat());\n  const minimized = $derived(isTopologyMinimized());\n  const data = $derived(topologyData());\n  const update = $derived(lastUpdate());\n  const instanceData = $derived(instances());\n  const runnersData = $derived(runners());\n  const downloadsData = $derived(downloads());\n  const previewsData = $derived(placementPreviews());\n  const selectedModelId = $derived(selectedPreviewModelId());\n  const loadingPreviews = $derived(isLoadingPreviews());\n  const debugEnabled = $derived(debugMode());\n  const topologyOnlyEnabled = $derived(topologyOnlyMode());\n  const sidebarVisible = $derived(chatSidebarVisible());\n  const mobileChatOpen = $derived(mobileChatSidebarOpen());\n  const mobileRightOpen = $derived(mobileRightSidebarOpen());\n  const tbBridgeCycles = $derived(thunderboltBridgeCycles());\n  const tbBridgeData = $derived(nodeThunderboltBridge());\n  const identitiesData = $derived(nodeIdentities());\n  const tbIdentifiers = $derived(nodeThunderbolt());\n  const rdmaCtlData = $derived(nodeRdmaCtl());\n  const nodeFilter = $derived(previewNodeFilter());\n\n  // Aggregate active download progress across all instances for header indicator\n  const activeDownloadSummary = $derived.by(() => {\n    let totalBytes = 0;\n    let downloadedBytes = 0;\n    let count = 0;\n    for (const [id, inst] of Object.entries(instanceData)) {\n      const status = getInstanceDownloadStatus(id, inst);\n      if (status.isDownloading && status.progress) {\n       
 count++;\n        totalBytes += status.progress.totalBytes || 0;\n        downloadedBytes += status.progress.downloadedBytes || 0;\n      }\n    }\n    if (count === 0) return null;\n    return {\n      count,\n      percentage: totalBytes > 0 ? (downloadedBytes / totalBytes) * 100 : 0,\n    };\n  });\n\n  // Detect macOS version mismatches across cluster nodes\n  const macosVersionMismatch = $derived.by(() => {\n    if (!identitiesData) return null;\n    const entries = Object.entries(identitiesData);\n    // Filter to macOS nodes (version starts with a digit, e.g. \"15.3\")\n    const macosNodes = entries.filter(([_, id]) => {\n      const v = id.osVersion;\n      return v && v !== \"Unknown\" && /^\\d/.test(v);\n    });\n    if (macosNodes.length < 2) return null;\n    // Compare on buildVersion for precise mismatch detection\n    const buildVersions = new Set(\n      macosNodes.map(([_, id]) => id.osBuildVersion ?? id.osVersion),\n    );\n    if (buildVersions.size <= 1) return null;\n    return macosNodes.map(([nodeId, id]) => ({\n      nodeId,\n      friendlyName: getNodeName(nodeId),\n      version: id.osVersion!,\n      buildVersion: id.osBuildVersion ?? \"Unknown\",\n    }));\n  });\n\n  // Detect TB5 nodes where RDMA is not enabled\n  const tb5WithoutRdma = $derived.by(() => {\n    const rdmaCtl = rdmaCtlData;\n    if (!rdmaCtl) return false;\n    const ids = tbIdentifiers;\n    if (!ids) return false;\n    // Find nodes with TB5 hardware (any TB interface)\n    const tb5NodeIds = Object.entries(ids)\n      .filter(([_, node]) => node.interfaces.length > 0)\n      .map(([id]) => id);\n    if (tb5NodeIds.length < 2) return false;\n    // At least one TB5 node has RDMA disabled\n    return tb5NodeIds.some((id) => rdmaCtl[id]?.enabled !== true);\n  });\n  let tb5InfoDismissed = $state(false);\n\n  // Detect Mac Studio nodes using RDMA on en2 (the port next to ethernet — RDMA doesn't work there)\n  const macStudioEn2RdmaWarning = $derived.by(() => {\n    const edges = data?.edges;\n    const ids = tbIdentifiers;\n    const rdmaCtl = rdmaCtlData;\n    if (!edges || !ids || !rdmaCtl) return null;\n\n    const affectedConnections: Array<{\n      nodeId: string;\n      nodeName: string;\n      peerNodeId: string;\n      peerNodeName: string;\n      rdmaIface: string;\n    }> = [];\n\n    const isMacStudio = (node: (typeof data.nodes)[string] | undefined) =>\n      node?.system_info?.model_id === \"Mac Studio\";\n\n    for (const edge of edges) {\n      if (!edge.sourceRdmaIface && !edge.sinkRdmaIface) continue;\n\n      const sourceNode = data?.nodes?.[edge.source];\n      if (\n        isMacStudio(sourceNode) &&\n        edge.sourceRdmaIface === \"rdma_en2\" &&\n        rdmaCtl[edge.source]?.enabled\n      ) {\n        affectedConnections.push({\n          nodeId: edge.source,\n          nodeName:\n            sourceNode?.friendly_name || edge.source.slice(0, 8) + \"...\",\n          peerNodeId: edge.target,\n          peerNodeName:\n            data?.nodes?.[edge.target]?.friendly_name ||\n            edge.target.slice(0, 8) + \"...\",\n          rdmaIface: \"en2\",\n        });\n      }\n\n      const sinkNode = data?.nodes?.[edge.target];\n      if (\n        isMacStudio(sinkNode) &&\n        edge.sinkRdmaIface === \"rdma_en2\" &&\n        rdmaCtl[edge.target]?.enabled\n      ) {\n        affectedConnections.push({\n          nodeId: edge.target,\n          nodeName: sinkNode?.friendly_name || edge.target.slice(0, 8) + \"...\",\n          peerNodeId: edge.source,\n          
peerNodeName:\n            sourceNode?.friendly_name || edge.source.slice(0, 8) + \"...\",\n          rdmaIface: \"en2\",\n        });\n      }\n    }\n\n    // Deduplicate by nodeId\n    const seen = new Set<string>();\n    const unique = affectedConnections.filter((c) => {\n      if (seen.has(c.nodeId)) return false;\n      seen.add(c.nodeId);\n      return true;\n    });\n\n    return unique.length > 0 ? unique : null;\n  });\n  let macStudioEn2Dismissed = $state(false);\n\n  // Helper to get friendly node name from node ID\n  function getNodeName(nodeId: string): string {\n    const node = data?.nodes?.[nodeId];\n    return node?.friendly_name || nodeId.slice(0, 8) + \"...\";\n  }\n\n  // Helper to get the thunderbolt bridge service name from a cycle\n  function getTbBridgeServiceName(cycle: string[]): string {\n    // Try to find service name from any node in the cycle\n    for (const nodeId of cycle) {\n      const nodeData = tbBridgeData?.[nodeId];\n      if (nodeData?.serviceName) {\n        return nodeData.serviceName;\n      }\n    }\n    return \"Thunderbolt Bridge\"; // Fallback if no service name found\n  }\n\n  // Copy to clipboard state and function\n  let copiedCommand = $state(false);\n  async function copyToClipboard(text: string) {\n    try {\n      await navigator.clipboard.writeText(text);\n      copiedCommand = true;\n      setTimeout(() => {\n        copiedCommand = false;\n      }, 2000);\n    } catch (err) {\n      console.error(\"Failed to copy:\", err);\n    }\n  }\n\n  // Warning icon SVG path (reused across warning snippets)\n  const warningIconPath =\n    \"M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z\";\n  const infoIconPath =\n    \"M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z\";\n\n  let mounted = $state(false);\n  let localNodeId = $state<string | null>(null);\n\n  // ── Onboarding wizard state ──\n  const ONBOARDING_COMPLETE_KEY = \"exo-onboarding-complete\";\n  let onboardingStep = $state(0); // 0 = not in onboarding, 1-9 = wizard steps\n  let onboardingModelId = $state<string | null>(null); // model selected during onboarding\n  let onboardingFadingOut = $state(false); // true during fade-out transition\n  const showOnboarding = $derived(onboardingStep > 0);\n  const showOnboardingOverlay = $derived(showOnboarding || onboardingFadingOut);\n\n  // ── Steps 1-5 animation state: cinematic SVG story ──\n  const SIMULATED_STUDIO_GB = 256; // simulated Mac Studio memory\n  const onboardingCombinedGB = $derived(\n    userDeviceInfo.memoryGB + SIMULATED_STUDIO_GB,\n  );\n\n  // Models unlocked by adding the second device — one per base model, well-known preferred\n  const unlockedModels = $derived.by(() => {\n    if (models.length === 0) return [];\n    const singleGB = userDeviceInfo.memoryGB;\n    const combinedGB = onboardingCombinedGB;\n    const candidates = models\n      .filter((m) => {\n        const sizeGB = getModelSizeGB(m);\n        return sizeGB > singleGB && sizeGB <= combinedGB && m.family;\n      })\n      .sort((a, b) => getModelSizeGB(a) - getModelSizeGB(b));\n    // Deduplicate by base_model (or family as fallback) — keep smallest quant per base\n    const seen = new Set<string>();\n    const deduped: typeof candidates = [];\n    for (const m of candidates) {\n      const key = m.base_model || m.family || m.id;\n      if (seen.has(key)) continue;\n      seen.add(key);\n      deduped.push(m);\n    }\n    return deduped.slice(0, 3);\n  });\n\n  
// User device info from topology — uses /node_id to find our own node\n  const userDeviceInfo = $derived.by(() => {\n    if (!data || Object.keys(data.nodes).length === 0) {\n      return { name: \"MacBook Pro\", memoryGB: 36, deviceType: \"macbook pro\" };\n    }\n    const ourNode = localNodeId ? data.nodes[localNodeId] : undefined;\n    const node = ourNode ?? Object.values(data.nodes)[0];\n    const totalMem =\n      node.macmon_info?.memory?.ram_total ?? node.system_info?.memory ?? 0;\n    const memGB = Math.round(totalMem / (1024 * 1024 * 1024));\n    const name = node.friendly_name || \"Your Mac\";\n    const modelId = (node.system_info?.model_id || \"macbook pro\").toLowerCase();\n    return { name, memoryGB: memGB || 36, deviceType: modelId };\n  });\n\n  let showContinueButton = $state(false);\n  let stepTitle = $state(\"\");\n  let stepTransitioning = $state(false);\n\n  // Advance to the next onboarding step\n  function advanceStep(target: number) {\n    showContinueButton = false;\n    if (target <= 5) {\n      // Steps 1-5 share a persistent SVG canvas — just set the step directly\n      onboardingStep = target;\n    } else {\n      // Leaving the cinematic sequence — fade out, then switch\n      stepTransitioning = true;\n      setTimeout(() => {\n        onboardingStep = target;\n        stepTransitioning = false;\n      }, 350);\n    }\n  }\n\n  // Tweened animation values for the persistent SVG canvas\n  const device1X = tweened(350, { duration: 800, easing: cubicInOut });\n  const device2X = tweened(550, { duration: 800, easing: cubicInOut });\n  const device2Opacity = tweened(0, { duration: 600, easing: cubicOut });\n  const connectionOpacity = tweened(0, { duration: 500, easing: cubicOut });\n  const connectionIsRed = tweened(0, { duration: 500, easing: cubicOut }); // 0=gold, 1=red\n  const combinedLabelOpacity = tweened(0, { duration: 500, easing: cubicOut });\n  const modelBlockY = tweened(20, { duration: 700, easing: cubicInOut });\n  const modelBlockOpacity = tweened(0, { duration: 500, easing: cubicOut });\n  const modelSplitProgress = tweened(0, { duration: 800, easing: cubicInOut }); // 0=unified, 1=fully split\n  const disconnectXOpacity = tweened(0, { duration: 400, easing: cubicOut });\n  const device1Opacity = tweened(1, { duration: 600, easing: cubicOut });\n  const logoOpacity = tweened(1, { duration: 600, easing: cubicOut });\n  // Step 2 chip fade: 0→N where each chip fades in at its stagger offset\n  const chipPhase = tweened(0, { duration: 800, easing: cubicOut });\n  const deviceCountOpacity = tweened(0, { duration: 600, easing: cubicOut });\n  const topologyOpacity = tweened(1, { duration: 400, easing: cubicOut });\n  const titleOpacity = tweened(0, { duration: 500, easing: cubicOut });\n  const subtitleOpacity = tweened(0, { duration: 500, easing: cubicOut });\n\n  // ── Step 1: \"Your EXO Network\" — show real topology ──\n  $effect(() => {\n    if (onboardingStep === 1) {\n      showContinueButton = false;\n      stepTitle = \"\";\n      // Reset all tweens to initial\n      device1X.set(350, { duration: 0 });\n      device1Opacity.set(0, { duration: 0 });\n      device2Opacity.set(0, { duration: 0 });\n      connectionOpacity.set(0, { duration: 0 });\n      connectionIsRed.set(0, { duration: 0 });\n      combinedLabelOpacity.set(0, { duration: 0 });\n      modelBlockOpacity.set(0, { duration: 0 });\n      modelSplitProgress.set(0, { duration: 0 });\n      disconnectXOpacity.set(0, { duration: 0 });\n      logoOpacity.set(1, { duration: 0 });\n  
    titleOpacity.set(0, { duration: 0 });\n      subtitleOpacity.set(0, { duration: 0 });\n      chipPhase.set(0, { duration: 0 });\n      deviceCountOpacity.set(0, { duration: 0 });\n      topologyOpacity.set(1, { duration: 0 });\n\n      const t1 = setTimeout(() => {\n        titleOpacity.set(1);\n      }, 300);\n      const t2 = setTimeout(() => {\n        deviceCountOpacity.set(1);\n      }, 800);\n      const t3 = setTimeout(() => {\n        showContinueButton = true;\n      }, 1200);\n\n      return () => {\n        clearTimeout(t1);\n        clearTimeout(t2);\n        clearTimeout(t3);\n      };\n    }\n  });\n\n  // ── Step 2: \"Add devices to run larger models\" — cross-fade topology out, device pair animates in ──\n  $effect(() => {\n    if (onboardingStep === 2) {\n      showContinueButton = false;\n\n      // Cross-fade: fade out real topology\n      topologyOpacity.set(0);\n\n      // Immediately transition out step 1 elements\n      logoOpacity.set(0);\n      deviceCountOpacity.set(0);\n      // Smoothly crossfade the title: fade old out, update text, fade new in\n      titleOpacity.set(0, { duration: 300 });\n      subtitleOpacity.set(0, { duration: 0 });\n\n      // Delay all step 2 animations by 400ms to let topology fade out\n      const DELAY = 400;\n\n      const t0 = setTimeout(() => {\n        stepTitle = \"Add devices to run larger models\";\n        titleOpacity.set(1, { duration: 400 });\n      }, DELAY + 300);\n\n      const t1 = setTimeout(() => {\n        device1Opacity.set(1, { duration: 0 });\n        device1X.set(220);\n        device2X.set(480, { duration: 0 });\n        device2Opacity.set(0, { duration: 0 });\n      }, DELAY + 200);\n      const t2 = setTimeout(() => {\n        device2Opacity.set(1);\n        device2X.set(480);\n      }, DELAY + 700);\n      const t3 = setTimeout(() => {\n        connectionOpacity.set(1);\n      }, DELAY + 1200);\n      const t4 = setTimeout(() => {\n        combinedLabelOpacity.set(1);\n      }, DELAY + 1600);\n      // Staggered chip fade-in (each chip offsets by 0.6 in chipPhase)\n      const t5 = setTimeout(() => {\n        chipPhase.set(3, { duration: 1800 });\n      }, DELAY + 1800);\n      const t6 = setTimeout(() => {\n        showContinueButton = true;\n      }, DELAY + 3200);\n\n      return () => {\n        clearTimeout(t0);\n        clearTimeout(t1);\n        clearTimeout(t2);\n        clearTimeout(t3);\n        clearTimeout(t4);\n        clearTimeout(t5);\n        clearTimeout(t6);\n      };\n    }\n  });\n\n  // ── Step 3: \"exo splits the model\" — model block appears, splits ──\n  $effect(() => {\n    if (onboardingStep === 3) {\n      showContinueButton = false;\n      // Gently fade out the unlock chips\n      chipPhase.set(0, { duration: 600 });\n\n      // Crossfade title\n      titleOpacity.set(0, { duration: 250 });\n      subtitleOpacity.set(0, { duration: 250 });\n      setTimeout(() => {\n        stepTitle = \"exo splits models across devices\";\n        titleOpacity.set(1, { duration: 400 });\n        subtitleOpacity.set(1, { duration: 400 });\n      }, 250);\n\n      // Wait for chips to fade before showing model block\n      const t1 = setTimeout(() => {\n        modelBlockOpacity.set(1);\n        modelBlockY.set(50);\n      }, 600);\n      const t2 = setTimeout(() => {\n        modelSplitProgress.set(1);\n      }, 1500);\n      const t3 = setTimeout(() => {\n        showContinueButton = true;\n      }, 2300);\n\n      return () => {\n        clearTimeout(t1);\n        clearTimeout(t2);\n        
clearTimeout(t3);\n      };\n    }\n  });\n\n  // ── Step 4: \"A device disconnects... exo self-heals\" — full disconnect+heal sequence ──\n  $effect(() => {\n    if (onboardingStep === 4) {\n      showContinueButton = false;\n\n      // Crossfade title\n      titleOpacity.set(0, { duration: 250 });\n      subtitleOpacity.set(0, { duration: 250 });\n      setTimeout(() => {\n        stepTitle = \"When a device disconnects...\";\n        titleOpacity.set(1, { duration: 400 });\n        subtitleOpacity.set(1, { duration: 400 });\n      }, 250);\n\n      // Phase 1: Disconnect\n      const t1 = setTimeout(() => {\n        connectionIsRed.set(1);\n      }, 400);\n      const t2 = setTimeout(() => {\n        disconnectXOpacity.set(1);\n      }, 800);\n      const t3 = setTimeout(() => {\n        device2Opacity.set(0);\n        connectionOpacity.set(0);\n        disconnectXOpacity.set(0);\n        combinedLabelOpacity.set(0);\n      }, 1600);\n\n      // Phase 2: Self-heal — crossfade title + subtitle\n      const t4 = setTimeout(() => {\n        titleOpacity.set(0, { duration: 250 });\n        subtitleOpacity.set(0, { duration: 250 });\n      }, 2550);\n      const t4b = setTimeout(() => {\n        stepTitle = \"exo self-heals\";\n        titleOpacity.set(1, { duration: 400 });\n        subtitleOpacity.set(1, { duration: 400 });\n      }, 2800);\n      const t5 = setTimeout(() => {\n        device1X.set(350);\n        device2X.set(350);\n      }, 3100);\n      const t6 = setTimeout(() => {\n        modelSplitProgress.set(0);\n        modelBlockY.set(20); // Lift up while merging\n        connectionIsRed.set(0);\n      }, 3700);\n      const t7 = setTimeout(() => {\n        modelBlockY.set(125); // Settle back down just above the device\n      }, 4800);\n      const t8 = setTimeout(() => {\n        advanceStep(6);\n      }, 6200);\n\n      return () => {\n        clearTimeout(t1);\n        clearTimeout(t2);\n        clearTimeout(t3);\n        clearTimeout(t4);\n        clearTimeout(t4b);\n        clearTimeout(t5);\n        clearTimeout(t6);\n        clearTimeout(t7);\n        clearTimeout(t8);\n      };\n    }\n  });\n\n  // Recommended models for onboarding: 2 large, 2 medium, 2 small\n  // Always includes Llama-3.2-3B-4bit as a fast-loading small option\n  const PINNED_ONBOARDING_MODEL = \"mlx-community/Llama-3.2-3B-Instruct-4bit\";\n  const onboardingModels = $derived.by(() => {\n    if (models.length === 0) return [];\n    const sorted = [...models]\n      .filter((m) => hasEnoughMemory(m) && getModelSizeGB(m) > 0)\n      .sort((a, b) => getModelSizeGB(b) - getModelSizeGB(a));\n    if (sorted.length <= 6) return sorted;\n\n    // Split into thirds by size: large (top third), medium (middle), small (bottom)\n    const third = Math.max(1, Math.floor(sorted.length / 3));\n    const large = sorted.slice(0, third);\n    const medium = sorted.slice(third, third * 2);\n    const small = sorted.slice(third * 2);\n\n    // Pick 2 from each tier, ensuring pinned model counts as a small pick\n    const pinned =\n      small.find((m) => m.id === PINNED_ONBOARDING_MODEL) ||\n      sorted.find((m) => m.id === PINNED_ONBOARDING_MODEL);\n    const pickLarge = large.slice(0, 2);\n    const pickMedium = medium.slice(0, 2);\n    const pickSmall = pinned\n      ? 
[\n          small.find((m) => m.id !== PINNED_ONBOARDING_MODEL) || small[0],\n          pinned,\n        ].filter(Boolean)\n      : small.slice(0, 2);\n\n    const result = [...pickLarge, ...pickMedium, ...pickSmall];\n    // Deduplicate (in case pinned was already picked)\n    const seen = new Set<string>();\n    return result.filter((m) => {\n      if (seen.has(m.id)) return false;\n      seen.add(m.id);\n      return true;\n    });\n  });\n\n  // Track onboarding instance status for auto-advancing steps.\n  // Uses runner status as source of truth to avoid false \"ready\" from missing download data.\n  // Only tracks the specific model launched during onboarding (ignores other running instances).\n  $effect(() => {\n    if (onboardingStep === 7 && instanceCount > 0 && onboardingModelId) {\n      let anyDownloading = false;\n      let anyReady = false;\n      for (const [id, inst] of Object.entries(instanceData)) {\n        // Only check instances for the model we launched during onboarding\n        if (getInstanceModelId(inst) !== onboardingModelId) continue;\n        const runnerStatus = deriveInstanceStatus(inst);\n        if (\n          runnerStatus.statusText === \"READY\" ||\n          runnerStatus.statusText === \"LOADED\" ||\n          runnerStatus.statusText === \"RUNNING\"\n        ) {\n          anyReady = true;\n        } else if (runnerStatus.statusText === \"DOWNLOADING\") {\n          anyDownloading = true;\n        } else {\n          const dlStatus = getInstanceDownloadStatus(id, inst);\n          if (dlStatus.isDownloading) anyDownloading = true;\n        }\n      }\n      // Model already cached & ready — skip download AND loading steps\n      if (anyReady) {\n        onboardingStep = 9;\n      } else if (anyDownloading) {\n        // Stay on step 7 (downloading)\n      } else {\n        // Not ready and not downloading — could be loading, initializing, or preparing.\n        // Only advance to step 8 if runners are actually in a loading state.\n        for (const [, inst] of Object.entries(instanceData)) {\n          if (getInstanceModelId(inst) !== onboardingModelId) continue;\n          const runnerStatus = deriveInstanceStatus(inst);\n          if (\n            runnerStatus.statusText === \"LOADING\" ||\n            runnerStatus.statusText === \"WARMING UP\"\n          ) {\n            onboardingStep = 8;\n            break;\n          }\n        }\n      }\n    }\n  });\n\n  $effect(() => {\n    if (onboardingStep === 8 && instanceCount > 0 && onboardingModelId) {\n      for (const [, inst] of Object.entries(instanceData)) {\n        if (getInstanceModelId(inst) !== onboardingModelId) continue;\n        const runnerStatus = deriveInstanceStatus(inst);\n        if (\n          runnerStatus.statusText === \"READY\" ||\n          runnerStatus.statusText === \"LOADED\" ||\n          runnerStatus.statusText === \"RUNNING\"\n        ) {\n          onboardingStep = 9;\n          break;\n        }\n      }\n    }\n  });\n\n  function completeOnboarding() {\n    // Trigger fade-out, then fully remove overlay\n    onboardingFadingOut = true;\n    onboardingStep = 0;\n    try {\n      localStorage.setItem(ONBOARDING_COMPLETE_KEY, \"true\");\n    } catch {\n      // ignore\n    }\n    // Persist to server (~/.exo)\n    fetch(\"/onboarding\", { method: \"POST\" }).catch(() => {});\n    // Remove overlay after fade-out transition completes\n    setTimeout(() => {\n      onboardingFadingOut = false;\n    }, 500);\n  }\n\n  // Auto-complete onboarding when user sends a message 
from step 9\n  $effect(() => {\n    if (onboardingStep === 9 && chatStarted) {\n      completeOnboarding();\n    }\n  });\n\n  let onboardingError = $state<string | null>(null);\n\n  async function onboardingLaunchModel(modelId: string) {\n    onboardingModelId = modelId;\n    onboardingError = null;\n    selectPreviewModel(modelId);\n    onboardingStep = 7;\n    // Launch via standard placement API (same as main dashboard)\n    // Single-node: force Pipeline/Ring regardless of persisted defaults\n    const nodeCount = topologyData()\n      ? Object.keys(topologyData()!.nodes).length\n      : 1;\n    const sharding = nodeCount <= 1 ? \"Pipeline\" : selectedSharding;\n    const instanceType = nodeCount <= 1 ? \"MlxRing\" : selectedInstanceType;\n    try {\n      const placementResponse = await fetch(\n        `/instance/placement?model_id=${encodeURIComponent(modelId)}&sharding=${sharding}&instance_meta=${instanceType}&min_nodes=1`,\n      );\n      if (!placementResponse.ok) {\n        const errorText = await placementResponse.text();\n        onboardingError = `Failed to get placement: ${errorText}`;\n        onboardingStep = 6;\n        return;\n      }\n      const instanceData = await placementResponse.json();\n      const response = await fetch(\"/instance\", {\n        method: \"POST\",\n        headers: { \"Content-Type\": \"application/json\" },\n        body: JSON.stringify({ instance: instanceData }),\n      });\n      if (!response.ok) {\n        const errorText = await response.text();\n        onboardingError = `Failed to launch: ${errorText}`;\n        onboardingStep = 6;\n        return;\n      }\n      setSelectedChatModel(modelId);\n      recordRecentLaunch(modelId);\n    } catch (error) {\n      onboardingError = `Network error: ${error}`;\n      onboardingStep = 6;\n    }\n  }\n\n  // Helper to get onboarding download progress\n  const onboardingDownloadProgress = $derived.by(() => {\n    if (instanceCount === 0) return null;\n    for (const [id, inst] of Object.entries(instanceData)) {\n      const status = getInstanceDownloadStatus(id, inst);\n      if (status.isDownloading && status.progress) {\n        return status.progress;\n      }\n    }\n    return null;\n  });\n\n  // Helper to get onboarding model loading progress (layers loaded)\n  const onboardingLoadProgress = $derived.by(() => {\n    if (instanceCount === 0 || !onboardingModelId) return null;\n    let layersLoaded = 0,\n      totalLayers = 0;\n    for (const [, inst] of Object.entries(instanceData)) {\n      if (getInstanceModelId(inst) !== onboardingModelId) continue;\n      const status = deriveInstanceStatus(inst);\n      if (\n        status.statusText === \"LOADING\" &&\n        status.totalLayers &&\n        status.totalLayers > 0\n      ) {\n        layersLoaded += status.layersLoaded ?? 
0;\n        totalLayers += status.totalLayers;\n      }\n    }\n    if (totalLayers === 0) return null;\n    return {\n      layersLoaded,\n      totalLayers,\n      percentage: (layersLoaded / totalLayers) * 100,\n    };\n  });\n\n  // Instance launch state\n  let models = $state<\n    Array<{\n      id: string;\n      name?: string;\n      storage_size_megabytes?: number;\n      tasks?: string[];\n      hugging_face_id?: string;\n      is_custom?: boolean;\n      family?: string;\n      quantization?: string;\n      base_model?: string;\n      capabilities?: string[];\n    }>\n  >([]);\n  type ModelMemoryFitStatus =\n    | \"fits_now\"\n    | \"fits_cluster_capacity\"\n    | \"too_large\";\n\n  // Model tasks lookup for ChatForm - maps both short IDs and full HuggingFace IDs\n  const modelTasks = $derived(() => {\n    const tasks: Record<string, string[]> = {};\n    for (const model of models) {\n      if (model.tasks && model.tasks.length > 0) {\n        // Map by short ID\n        tasks[model.id] = model.tasks;\n        // Also map by hugging_face_id from the API response\n        if (model.hugging_face_id) {\n          tasks[model.hugging_face_id] = model.tasks;\n        }\n      }\n    }\n    return tasks;\n  });\n\n  const modelCapabilities = $derived(() => {\n    const caps: Record<string, string[]> = {};\n    for (const model of models) {\n      if (model.capabilities && model.capabilities.length > 0) {\n        caps[model.id] = model.capabilities;\n        if (model.hugging_face_id) {\n          caps[model.hugging_face_id] = model.capabilities;\n        }\n      }\n    }\n    return caps;\n  });\n\n  // Helper to check if a model supports image generation\n  function modelSupportsImageGeneration(modelId: string): boolean {\n    const model = models.find(\n      (m) => m.id === modelId || m.hugging_face_id === modelId,\n    );\n    if (!model?.tasks) return false;\n    return (\n      model.tasks.includes(\"TextToImage\") ||\n      model.tasks.includes(\"ImageToImage\")\n    );\n  }\n\n  // Helper to check if a model supports image editing\n  function modelSupportsImageEditing(modelId: string): boolean {\n    const model = models.find(\n      (m) => m.id === modelId || m.hugging_face_id === modelId,\n    );\n    if (!model?.tasks) return false;\n    return model.tasks.includes(\"ImageToImage\");\n  }\n\n  // Route a message to the correct endpoint based on model capabilities.\n  // Image models go to generateImage/editImage; text models go to sendMessage.\n  function routeMessage(\n    content: string,\n    files?: {\n      id: string;\n      name: string;\n      type: string;\n      textContent?: string;\n      preview?: string;\n    }[],\n  ) {\n    const model = selectedChatModel();\n    if (!model) {\n      sendMessage(content, files, null);\n      return;\n    }\n\n    const currentEditImage = editingImage();\n\n    // Image editing mode (explicit edit or attached image with ImageToImage model)\n    if (currentEditImage && content && modelSupportsImageEditing(model)) {\n      editImage(content, currentEditImage.imageDataUrl);\n      return;\n    }\n    if (\n      modelSupportsImageEditing(model) &&\n      files?.length &&\n      files[0].preview &&\n      content\n    ) {\n      editImage(content, files[0].preview);\n      return;\n    }\n\n    // Text-to-image generation\n    if (modelSupportsImageGeneration(model) && content) {\n      generateImage(content);\n      return;\n    }\n\n    // Default: text chat\n    sendMessage(content, files, null);\n  }\n\n  let 
selectedSharding = $state<\"Pipeline\" | \"Tensor\">(\"Pipeline\");\n  type InstanceMeta = \"MlxRing\" | \"MlxJaccl\";\n\n  // Launch defaults persistence\n  const LAUNCH_DEFAULTS_KEY = \"exo-launch-defaults-v2\";\n  interface LaunchDefaults {\n    modelId: string | null;\n    sharding: \"Pipeline\" | \"Tensor\";\n    instanceType: InstanceMeta;\n    minNodes: number;\n  }\n\n  function saveLaunchDefaults(): void {\n    const defaults: LaunchDefaults = {\n      modelId: selectedPreviewModelId(),\n      sharding: selectedSharding,\n      instanceType: selectedInstanceType,\n      minNodes: selectedMinNodes,\n    };\n    try {\n      localStorage.setItem(LAUNCH_DEFAULTS_KEY, JSON.stringify(defaults));\n    } catch (e) {\n      console.warn(\"Failed to save launch defaults:\", e);\n    }\n  }\n\n  function loadLaunchDefaults(): LaunchDefaults | null {\n    try {\n      const stored = localStorage.getItem(LAUNCH_DEFAULTS_KEY);\n      if (!stored) return null;\n      return JSON.parse(stored) as LaunchDefaults;\n    } catch (e) {\n      console.warn(\"Failed to load launch defaults:\", e);\n      return null;\n    }\n  }\n\n  function applyLaunchDefaults(\n    availableModels: Array<{ id: string }>,\n    maxNodes: number,\n  ): void {\n    const defaults = loadLaunchDefaults();\n    if (!defaults) return;\n\n    // Apply sharding and instance type unconditionally\n    selectedSharding = defaults.sharding;\n    selectedInstanceType =\n      defaults.instanceType === \"MlxRing\" ? \"MlxRing\" : \"MlxJaccl\";\n\n    // Apply minNodes if valid (between 1 and maxNodes)\n    if (\n      defaults.minNodes &&\n      defaults.minNodes >= 1 &&\n      defaults.minNodes <= maxNodes\n    ) {\n      selectedMinNodes = defaults.minNodes;\n    }\n\n    // Only apply model if it exists in the available models\n    if (\n      defaults.modelId &&\n      availableModels.some((m) => m.id === defaults.modelId)\n    ) {\n      selectPreviewModel(defaults.modelId);\n      setSelectedChatModel(defaults.modelId);\n    }\n  }\n\n  let selectedInstanceType = $state<InstanceMeta>(\"MlxRing\");\n  let selectedMinNodes = $state<number>(1);\n  let minNodesInitialized = $state(false);\n  let launchingModelId = $state<string | null>(null);\n  let instanceDownloadExpandedNodes = $state<Set<string>>(new Set());\n\n  // Model picker modal state\n  let isModelPickerOpen = $state(false);\n\n  // Advanced options toggle (hides technical jargon for new users)\n  let showAdvancedOptions = $state(false);\n\n  // Favorites state (reactive)\n  const favoritesSet = $derived(getFavoritesSet());\n\n  // Recent models state (reactive)\n  const recentModelIds = $derived(getRecentModelIds());\n  const showRecentsTab = $derived(hasRecents());\n\n  // Slider dragging state\n  let isDraggingSlider = $state(false);\n  let sliderTrackElement: HTMLDivElement | null = $state(null);\n\n  // Instances container ref for scrolling\n  let instancesContainerRef: HTMLDivElement | null = $state(null);\n  // Chat scroll container ref for precise scroll behavior\n  let chatScrollRef: HTMLDivElement | null = $state(null);\n\n  // Instance hover state for highlighting nodes in topology\n  let hoveredInstanceId = $state<string | null>(null);\n\n  // Preview card hover state for highlighting nodes in topology\n  let hoveredPreviewNodes = $state<Set<string>>(new Set());\n\n  // Computed: Check if filter is active (from store)\n  const isFilterActive = $derived(() => nodeFilter.size > 0);\n\n  // Helper to unwrap tagged instance for hover highlighting\n  function 
unwrapInstanceNodes(instanceWrapped: unknown): Set<string> {\n    if (!instanceWrapped || typeof instanceWrapped !== \"object\")\n      return new Set();\n    const keys = Object.keys(instanceWrapped as Record<string, unknown>);\n    if (keys.length !== 1) return new Set();\n    const instance = (instanceWrapped as Record<string, unknown>)[keys[0]];\n    if (!instance || typeof instance !== \"object\") return new Set();\n    const inst = instance as {\n      shardAssignments?: { nodeToRunner?: Record<string, string> };\n    };\n    if (!inst.shardAssignments?.nodeToRunner) return new Set();\n    return new Set(Object.keys(inst.shardAssignments.nodeToRunner));\n  }\n\n  function toggleInstanceDownloadDetails(nodeId: string): void {\n    const next = new Set(instanceDownloadExpandedNodes);\n    if (next.has(nodeId)) {\n      next.delete(nodeId);\n    } else {\n      next.add(nodeId);\n    }\n    instanceDownloadExpandedNodes = next;\n  }\n\n  // Compute highlighted nodes from hovered instance or hovered preview\n  const highlightedNodes = $derived(() => {\n    // First check instance hover\n    if (hoveredInstanceId) {\n      const instanceWrapped = instanceData[hoveredInstanceId];\n      return unwrapInstanceNodes(instanceWrapped);\n    }\n    // Then check preview hover\n    if (hoveredPreviewNodes.size > 0) {\n      return hoveredPreviewNodes;\n    }\n    return new Set<string>();\n  });\n\n  // Helper to estimate memory from model ID (mirrors ModelCard logic)\n  // Uses regex with word boundaries to avoid false matches like '4bit' matching '4b'\n  function estimateMemoryGB(modelId: string, modelName?: string): number {\n    // Check both ID and name for quantization info\n    const combined = `${modelId} ${modelName || \"\"}`.toLowerCase();\n\n    // Detect quantization level - affects memory by roughly 2x between levels\n    const is4bit =\n      combined.includes(\"4bit\") ||\n      combined.includes(\"4-bit\") ||\n      combined.includes(\":4bit\");\n    const is8bit =\n      combined.includes(\"8bit\") ||\n      combined.includes(\"8-bit\") ||\n      combined.includes(\":8bit\");\n    // 4-bit = 0.5 bytes/param, 8-bit = 1 byte/param, fp16 = 2 bytes/param\n    const quantMultiplier = is4bit ? 0.5 : is8bit ? 
1 : 2;\n    const id = modelId.toLowerCase();\n\n    // Known large models that don't follow the standard naming pattern\n    // DeepSeek V3 has 685B parameters\n    if (id.includes(\"deepseek-v3\")) {\n      return Math.round(685 * quantMultiplier);\n    }\n    // DeepSeek V2 has 236B parameters\n    if (id.includes(\"deepseek-v2\")) {\n      return Math.round(236 * quantMultiplier);\n    }\n    // Llama 4 Scout/Maverick are large models\n    if (id.includes(\"llama-4\")) {\n      return Math.round(400 * quantMultiplier);\n    }\n\n    // Match parameter counts with word boundaries (e.g., \"70b\" but not \"4bit\")\n    const paramMatch = id.match(/(\\d+(?:\\.\\d+)?)\\s*b(?![a-z])/i);\n    if (paramMatch) {\n      const params = parseFloat(paramMatch[1]);\n      return Math.max(4, Math.round(params * quantMultiplier));\n    }\n\n    // Fallback patterns for explicit size markers (assume fp16 baseline, adjust for quant)\n    if (id.includes(\"405b\") || id.includes(\"400b\"))\n      return Math.round(405 * quantMultiplier);\n    if (id.includes(\"180b\")) return Math.round(180 * quantMultiplier);\n    if (id.includes(\"141b\") || id.includes(\"140b\"))\n      return Math.round(140 * quantMultiplier);\n    if (id.includes(\"123b\") || id.includes(\"120b\"))\n      return Math.round(123 * quantMultiplier);\n    if (id.includes(\"72b\") || id.includes(\"70b\"))\n      return Math.round(70 * quantMultiplier);\n    if (id.includes(\"67b\") || id.includes(\"65b\"))\n      return Math.round(65 * quantMultiplier);\n    if (\n      id.includes(\"35b\") ||\n      id.includes(\"34b\") ||\n      id.includes(\"32b\") ||\n      id.includes(\"30b\")\n    )\n      return Math.round(32 * quantMultiplier);\n    if (id.includes(\"27b\") || id.includes(\"26b\") || id.includes(\"22b\"))\n      return Math.round(24 * quantMultiplier);\n    if (id.includes(\"14b\") || id.includes(\"13b\") || id.includes(\"15b\"))\n      return Math.round(14 * quantMultiplier);\n    if (id.includes(\"8b\") || id.includes(\"9b\") || id.includes(\"7b\"))\n      return Math.round(8 * quantMultiplier);\n    if (id.includes(\"3b\") || id.includes(\"3.8b\"))\n      return Math.round(4 * quantMultiplier);\n    if (\n      id.includes(\"2b\") ||\n      id.includes(\"1b\") ||\n      id.includes(\"1.5b\") ||\n      id.includes(\"0.5b\")\n    )\n      return Math.round(2 * quantMultiplier);\n    return 16; // Default fallback\n  }\n\n  // Helper to estimate performance from model ID\n  function estimatePerformance(modelId: string): { ttft: number; tps: number } {\n    const id = modelId.toLowerCase();\n    if (id.includes(\"405b\") || id.includes(\"400b\"))\n      return { ttft: 8000, tps: 3 };\n    if (id.includes(\"180b\")) return { ttft: 4000, tps: 5 };\n    if (id.includes(\"141b\") || id.includes(\"140b\"))\n      return { ttft: 3500, tps: 6 };\n    if (id.includes(\"123b\") || id.includes(\"120b\"))\n      return { ttft: 3000, tps: 7 };\n    if (id.includes(\"72b\") || id.includes(\"70b\"))\n      return { ttft: 1800, tps: 12 };\n    if (id.includes(\"67b\") || id.includes(\"65b\"))\n      return { ttft: 1600, tps: 14 };\n    if (\n      id.includes(\"35b\") ||\n      id.includes(\"34b\") ||\n      id.includes(\"32b\") ||\n      id.includes(\"30b\")\n    )\n      return { ttft: 900, tps: 22 };\n    if (id.includes(\"27b\") || id.includes(\"26b\") || id.includes(\"22b\"))\n      return { ttft: 700, tps: 28 };\n    if (id.includes(\"14b\") || id.includes(\"13b\") || id.includes(\"15b\"))\n      return { ttft: 400, tps: 45 };\n    if 
(id.includes(\"8b\") || id.includes(\"9b\") || id.includes(\"7b\"))\n      return { ttft: 200, tps: 65 };\n    if (id.includes(\"4b\") || id.includes(\"3b\") || id.includes(\"3.8b\"))\n      return { ttft: 100, tps: 95 };\n    if (\n      id.includes(\"2b\") ||\n      id.includes(\"1b\") ||\n      id.includes(\"1.5b\") ||\n      id.includes(\"0.5b\")\n    )\n      return { ttft: 50, tps: 150 };\n    return { ttft: 300, tps: 50 };\n  }\n\n  const matchesSelectedRuntime = (runtime: InstanceMeta): boolean =>\n    selectedInstanceType === \"MlxRing\"\n      ? runtime === \"MlxRing\"\n      : runtime === \"MlxJaccl\";\n\n  // Helper to check if a model can be launched (has valid placement with >= minNodes)\n  function canModelFit(modelId: string): boolean {\n    // Find previews matching the model, sharding, and instance type\n    const matchingPreviews = previewsData.filter(\n      (p: PlacementPreview) =>\n        p.model_id === modelId &&\n        p.sharding === selectedSharding &&\n        matchesSelectedRuntime(p.instance_meta) &&\n        p.error === null &&\n        p.memory_delta_by_node !== null,\n    );\n\n    // Check if any preview has node count >= selectedMinNodes\n    return matchingPreviews.some(\n      (p: PlacementPreview) => getPreviewNodeCount(p) >= selectedMinNodes,\n    );\n  }\n\n  // Helper to get model size in GB (from megabytes)\n  function getModelSizeGB(model: {\n    id: string;\n    name?: string;\n    storage_size_megabytes?: number;\n  }): number {\n    if (model.storage_size_megabytes) {\n      return model.storage_size_megabytes / 1024;\n    }\n    return estimateMemoryGB(model.id, model.name);\n  }\n\n  // Calculate available memory in the cluster (in GB)\n  const availableMemoryGB = $derived(() => {\n    if (!data) return 0;\n    return (\n      Object.values(data.nodes).reduce((acc, n) => {\n        const total =\n          n.macmon_info?.memory?.ram_total ?? n.system_info?.memory ?? 0;\n        const used = n.macmon_info?.memory?.ram_usage ?? 0;\n        return acc + (total - used);\n      }, 0) /\n      (1024 * 1024 * 1024)\n    );\n  });\n\n  // Calculate total memory in the cluster (in GB)\n  const clusterTotalMemoryGB = $derived(() => {\n    if (!data) return 0;\n    return (\n      Object.values(data.nodes).reduce((acc, n) => {\n        const total =\n          n.macmon_info?.memory?.ram_total ?? n.system_info?.memory ?? 
0;\n        return acc + total;\n      }, 0) /\n      (1024 * 1024 * 1024)\n    );\n  });\n\n  function getModelMemoryFitStatus(model: {\n    id: string;\n    name?: string;\n    storage_size_megabytes?: number;\n  }): ModelMemoryFitStatus {\n    const modelSizeGB = getModelSizeGB(model);\n    if (modelSizeGB <= availableMemoryGB()) {\n      return \"fits_now\";\n    }\n    if (modelSizeGB <= clusterTotalMemoryGB()) {\n      return \"fits_cluster_capacity\";\n    }\n    return \"too_large\";\n  }\n\n  // Check if a model has enough memory to run\n  function hasEnoughMemory(model: {\n    id: string;\n    name?: string;\n    storage_size_megabytes?: number;\n  }): boolean {\n    return getModelMemoryFitStatus(model) === \"fits_now\";\n  }\n\n  // Sorted models for dropdown - biggest first, unrunnable at the end\n  const sortedModels = $derived(() => {\n    return [...models].sort((a, b) => {\n      // First: models that have enough memory come before those that don't\n      const aCanFit = hasEnoughMemory(a);\n      const bCanFit = hasEnoughMemory(b);\n      if (aCanFit && !bCanFit) return -1;\n      if (!aCanFit && bCanFit) return 1;\n\n      // Then: sort by size (biggest first)\n      const aSize = getModelSizeGB(a);\n      const bSize = getModelSizeGB(b);\n      return bSize - aSize;\n    });\n  });\n\n  // Compute model tags (FASTEST, BIGGEST)\n  const modelTags = $derived(() => {\n    const tags: Record<string, string[]> = {};\n    if (models.length === 0) return tags;\n\n    // Find the fastest model (highest TPS)\n    let fastestId = \"\";\n    let highestTps = 0;\n\n    // Find the biggest model (most memory)\n    let biggestId = \"\";\n    let highestMemory = 0;\n\n    for (const model of models) {\n      const perf = estimatePerformance(model.id);\n      const mem = getModelSizeGB(model);\n\n      if (perf.tps > highestTps) {\n        highestTps = perf.tps;\n        fastestId = model.id;\n      }\n\n      if (mem > highestMemory) {\n        highestMemory = mem;\n        biggestId = model.id;\n      }\n    }\n\n    if (fastestId) {\n      tags[fastestId] = tags[fastestId] || [];\n      tags[fastestId].push(\"FASTEST\");\n    }\n\n    if (biggestId && biggestId !== fastestId) {\n      tags[biggestId] = tags[biggestId] || [];\n      tags[biggestId].push(\"BIGGEST\");\n    } else if (biggestId === fastestId && biggestId) {\n      // Same model is both - unlikely but handle it\n      tags[biggestId].push(\"BIGGEST\");\n    }\n\n    return tags;\n  });\n\n  onMount(async () => {\n    mounted = true;\n    fetchModels();\n    fetch(\"/node_id\")\n      .then((r) => (r.ok ? 
r.json() : null))\n      .then((id) => {\n        if (id) localNodeId = id;\n      })\n      .catch(() => {});\n\n    // Handle reset-onboarding query parameter (triggered from native Settings)\n    const params = new URLSearchParams(window.location.search);\n    if (params.has(\"reset-onboarding\")) {\n      localStorage.removeItem(ONBOARDING_COMPLETE_KEY);\n      window.history.replaceState({}, \"\", window.location.pathname);\n      onboardingStep = 1;\n      return;\n    }\n\n    // Check server-side onboarding state (persisted in ~/.exo)\n    try {\n      const res = await fetch(\"/onboarding\");\n      if (res.ok) {\n        const data = await res.json();\n        if (!data.completed) {\n          onboardingStep = 1;\n        }\n        return;\n      }\n    } catch {\n      // Server unreachable — fall through to localStorage\n    }\n\n    // Fallback: check localStorage\n    if (!localStorage.getItem(ONBOARDING_COMPLETE_KEY)) {\n      onboardingStep = 1;\n    }\n  });\n\n  async function fetchModels() {\n    try {\n      const response = await fetch(\"/models\");\n      if (response.ok) {\n        const data = await response.json();\n        // API returns { data: [{ id, name }] } format\n        models = data.data || [];\n        // Restore last launch defaults if available\n        const currentNodeCount = topologyData()\n          ? Object.keys(topologyData()!.nodes).length\n          : 1;\n        applyLaunchDefaults(models, currentNodeCount);\n      }\n    } catch (error) {\n      console.error(\"Failed to fetch models:\", error);\n    }\n  }\n\n  async function addModelFromPicker(modelId: string) {\n    const response = await fetch(\"/models/add\", {\n      method: \"POST\",\n      headers: { \"Content-Type\": \"application/json\" },\n      body: JSON.stringify({ model_id: modelId }),\n    });\n\n    if (!response.ok) {\n      let message = `Failed to add model (${response.status}: ${response.statusText})`;\n      try {\n        const err = await response.json();\n        if (err.detail) message = err.detail;\n      } catch {\n        // use default message\n      }\n      throw new Error(message);\n    }\n\n    await fetchModels();\n  }\n\n  async function deleteCustomModel(modelId: string) {\n    try {\n      const response = await fetch(\n        `/models/custom/${encodeURIComponent(modelId)}`,\n        { method: \"DELETE\" },\n      );\n      if (response.ok) {\n        await fetchModels();\n      }\n    } catch {\n      console.error(\"Failed to delete custom model\");\n    }\n  }\n\n  function handleModelPickerSelect(modelId: string) {\n    selectPreviewModel(modelId);\n    setSelectedChatModel(modelId);\n    saveLaunchDefaults();\n    isModelPickerOpen = false;\n  }\n\n  async function launchInstance(\n    modelId: string,\n    specificPreview?: PlacementPreview | null,\n  ) {\n    if (!modelId || launchingModelId) return;\n\n    launchingModelId = modelId;\n\n    try {\n      // Use the specific preview if provided, otherwise fall back to filtered preview\n      const preview = specificPreview ?? 
filteredPreview();\n\n      let response: Response;\n      if (preview?.instance) {\n        // Launch with pre-computed placement from preview\n        response = await fetch(\"/instance\", {\n          method: \"POST\",\n          headers: { \"Content-Type\": \"application/json\" },\n          body: JSON.stringify({ instance: preview.instance }),\n        });\n      } else {\n        // No preview available — use place_instance to let server decide placement\n        response = await fetch(\"/place_instance\", {\n          method: \"POST\",\n          headers: { \"Content-Type\": \"application/json\" },\n          body: JSON.stringify({\n            model_id: modelId,\n            sharding: selectedSharding,\n            instance_meta: selectedInstanceType,\n            min_nodes: 1,\n          }),\n        });\n      }\n\n      if (!response.ok) {\n        const errorText = await response.text();\n        console.error(\"Failed to launch instance:\", errorText);\n        addToast({\n          type: \"error\",\n          message: `Failed to launch model: ${errorText}`,\n        });\n      } else {\n        addToast({ type: \"info\", message: `Launching model...` });\n        // Always auto-select the newly launched model so the user chats to what they just launched\n        setSelectedChatModel(modelId);\n\n        // Record the launch in recent models history\n        recordRecentLaunch(modelId);\n\n        // Scroll to the bottom of instances container to show the new instance\n        // Use multiple attempts to ensure DOM has updated with the new instance\n        const scrollToBottom = () => {\n          if (instancesContainerRef) {\n            instancesContainerRef.scrollTo({\n              top: instancesContainerRef.scrollHeight,\n              behavior: \"smooth\",\n            });\n          }\n        };\n        setTimeout(scrollToBottom, 200);\n        setTimeout(scrollToBottom, 500);\n        setTimeout(scrollToBottom, 1000);\n      }\n    } catch (error) {\n      console.error(\"Error launching instance:\", error);\n      addToast({\n        type: \"error\",\n        message: \"Failed to launch model. Check console for details.\",\n      });\n    } finally {\n      launchingModelId = null;\n    }\n  }\n\n  // Helper to extract model ID from download shard metadata\n  function extractModelIdFromDownload(\n    downloadPayload: Record<string, unknown>,\n  ): string | null {\n    const shardMetadata =\n      downloadPayload.shard_metadata ?? downloadPayload.shardMetadata;\n    if (!shardMetadata || typeof shardMetadata !== \"object\") return null;\n\n    // Shard metadata is a tagged union: { PipelineShardMetadata: {...} } or { TensorShardMetadata: {...} }\n    const shardObj = shardMetadata as Record<string, unknown>;\n    const shardKeys = Object.keys(shardObj);\n    if (shardKeys.length !== 1) return null;\n\n    const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;\n    if (!shardData) return null;\n\n    // Model meta is nested: shard.model_card.model_id\n    const modelMeta = shardData.model_card ?? shardData.modelCard;\n    if (!modelMeta || typeof modelMeta !== \"object\") return null;\n\n    const meta = modelMeta as Record<string, unknown>;\n    return (meta.model_id as string) ?? (meta.modelId as string) ?? null;\n  }\n\n  // Helper to parse download progress from payload\n  function parseDownloadProgress(\n    payload: Record<string, unknown>,\n  ): DownloadProgress | null {\n    const progress = payload.download_progress ?? 
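// events may use snake_case or camelCase keys; accept both spellings\n      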
payload.downloadProgress;\n    if (!progress || typeof progress !== \"object\") return null;\n\n    const prog = progress as Record<string, unknown>;\n    const totalBytes = getBytes(prog.total);\n    const downloadedBytes = getBytes(prog.downloaded);\n    const speed = (prog.speed as number) ?? 0;\n    const completedFiles =\n      (prog.completed_files as number) ?? (prog.completedFiles as number) ?? 0;\n    const totalFiles =\n      (prog.total_files as number) ?? (prog.totalFiles as number) ?? 0;\n    const etaMs = (prog.eta_ms as number) ?? (prog.etaMs as number) ?? 0;\n\n    const files: DownloadProgress[\"files\"] = [];\n    const filesObj = (prog.files ?? {}) as Record<string, unknown>;\n    for (const [fileName, fileData] of Object.entries(filesObj)) {\n      if (!fileData || typeof fileData !== \"object\") continue;\n      const fd = fileData as Record<string, unknown>;\n      const fTotal = getBytes(fd.total);\n      const fDownloaded = getBytes(fd.downloaded);\n      files.push({\n        name: fileName,\n        totalBytes: fTotal,\n        downloadedBytes: fDownloaded,\n        speed: (fd.speed as number) ?? 0,\n        etaMs: (fd.eta_ms as number) ?? (fd.etaMs as number) ?? 0,\n        percentage: fTotal > 0 ? (fDownloaded / fTotal) * 100 : 0,\n      });\n    }\n\n    return {\n      totalBytes,\n      downloadedBytes,\n      speed,\n      etaMs:\n        etaMs ||\n        (speed > 0 ? ((totalBytes - downloadedBytes) / speed) * 1000 : 0),\n      percentage: totalBytes > 0 ? (downloadedBytes / totalBytes) * 100 : 0,\n      completedFiles,\n      totalFiles,\n      files,\n    };\n  }\n\n  // Helper to get download status for a model (checks all downloads for matching model ID)\n  type NodeDownloadStatus = {\n    nodeId: string;\n    nodeName: string;\n    status: \"completed\" | \"partial\" | \"pending\" | \"downloading\";\n    percentage: number;\n    progress: DownloadProgress | null;\n  };\n\n  // Shared helper: collect per-node download status for a model across a set of nodes.\n  // Handles deduplication, entry parsing, and aggregation in one place.\n  function collectDownloadStatus(\n    modelId: string,\n    nodeIds?: string[],\n  ): {\n    isDownloading: boolean;\n    progress: DownloadProgress | null;\n    perNode: NodeDownloadStatus[];\n    failedError: string | null;\n  } {\n    const empty = {\n      isDownloading: false,\n      progress: null,\n      perNode: [] as NodeDownloadStatus[],\n      failedError: null,\n    };\n\n    if (!downloadsData || Object.keys(downloadsData).length === 0) {\n      return empty;\n    }\n\n    // Deduplicate by nodeId — a node can have multiple entries for the same model\n    // (e.g. PipelineShardMetadata + TensorShardMetadata). Keep the last entry,\n    // which is the most recently applied event.\n    const perNodeMap = new Map<string, NodeDownloadStatus>();\n\n    const nodeIdSet = nodeIds ? 
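// a null set means no filter: scan downloads from every node\n      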
new Set(nodeIds) : null;\n    for (const [nodeId, nodeDownloads] of Object.entries(downloadsData)) {\n      if (nodeIdSet && !nodeIdSet.has(nodeId)) continue;\n      if (!Array.isArray(nodeDownloads)) continue;\n\n      for (const downloadWrapped of nodeDownloads) {\n        if (!downloadWrapped || typeof downloadWrapped !== \"object\") continue;\n\n        const keys = Object.keys(downloadWrapped as Record<string, unknown>);\n        if (keys.length !== 1) continue;\n\n        const downloadKind = keys[0];\n        const downloadPayload = (downloadWrapped as Record<string, unknown>)[\n          downloadKind\n        ] as Record<string, unknown>;\n        if (!downloadPayload) continue;\n\n        const downloadModelId = extractModelIdFromDownload(downloadPayload);\n        if (!downloadModelId || downloadModelId !== modelId) continue;\n\n        // DownloadFailed — return with any data collected so far\n        if (downloadKind === \"DownloadFailed\") {\n          return {\n            isDownloading: false,\n            progress: null,\n            perNode: Array.from(perNodeMap.values()),\n            failedError:\n              (downloadPayload.errorMessage as string) ||\n              (downloadPayload.error_message as string) ||\n              \"Download failed\",\n          };\n        }\n\n        if (\n          downloadKind !== \"DownloadOngoing\" &&\n          downloadKind !== \"DownloadPending\" &&\n          downloadKind !== \"DownloadCompleted\"\n        )\n          continue;\n\n        const nodeName =\n          data?.nodes?.[nodeId]?.friendly_name ?? nodeId.slice(0, 8);\n\n        if (downloadKind === \"DownloadCompleted\") {\n          perNodeMap.set(nodeId, {\n            nodeId,\n            nodeName,\n            status: \"completed\",\n            percentage: 100,\n            progress: null,\n          });\n          continue;\n        }\n\n        if (downloadKind === \"DownloadPending\") {\n          const pendingDownloaded = getBytes(\n            downloadPayload.downloaded ??\n              downloadPayload.downloaded_bytes ??\n              downloadPayload.downloadedBytes,\n          );\n          const pendingTotal = getBytes(\n            downloadPayload.total ??\n              downloadPayload.total_bytes ??\n              downloadPayload.totalBytes,\n          );\n          if (pendingDownloaded <= 0 && pendingTotal <= 0) continue;\n          const pct =\n            pendingTotal > 0 ? (pendingDownloaded / pendingTotal) * 100 : 0;\n          perNodeMap.set(nodeId, {\n            nodeId,\n            nodeName,\n            status: pendingDownloaded > 0 ? 
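// any bytes already on disk mark the download as partial rather than queued\n              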
\"partial\" : \"pending\",\n            percentage: pct,\n            progress: null,\n          });\n          continue;\n        }\n\n        // DownloadOngoing\n        const progress = parseDownloadProgress(downloadPayload);\n        if (\n          !progress ||\n          (progress.downloadedBytes <= 0 && progress.totalBytes <= 0)\n        )\n          continue;\n\n        perNodeMap.set(nodeId, {\n          nodeId,\n          nodeName,\n          status: \"downloading\",\n          percentage: progress.percentage,\n          progress,\n        });\n      }\n    }\n\n    // Aggregate from deduplicated per-node entries\n    const perNode = Array.from(perNodeMap.values());\n    let totalBytes = 0;\n    let downloadedBytes = 0;\n    let totalSpeed = 0;\n    let completedFiles = 0;\n    let totalFiles = 0;\n    let isDownloading = false;\n    const allFiles: DownloadProgress[\"files\"] = [];\n\n    for (const node of perNode) {\n      if (node.status === \"downloading\" && node.progress) {\n        isDownloading = true;\n        totalBytes += node.progress.totalBytes;\n        downloadedBytes += node.progress.downloadedBytes;\n        totalSpeed += node.progress.speed;\n        completedFiles += node.progress.completedFiles;\n        totalFiles += node.progress.totalFiles;\n        allFiles.push(...node.progress.files);\n      }\n    }\n\n    if (!isDownloading) {\n      return {\n        isDownloading: false,\n        progress: null,\n        perNode,\n        failedError: null,\n      };\n    }\n\n    const remainingBytes = totalBytes - downloadedBytes;\n    const etaMs = totalSpeed > 0 ? (remainingBytes / totalSpeed) * 1000 : 0;\n\n    return {\n      isDownloading: true,\n      progress: {\n        totalBytes,\n        downloadedBytes,\n        speed: totalSpeed,\n        etaMs,\n        percentage: totalBytes > 0 ? 
(downloadedBytes / totalBytes) * 100 : 0,\n        completedFiles,\n        totalFiles,\n        files: allFiles,\n      },\n      perNode,\n      failedError: null,\n    };\n  }\n\n  function getModelDownloadStatus(\n    modelId: string,\n    nodeIds?: string[],\n  ): {\n    isDownloading: boolean;\n    progress: DownloadProgress | null;\n    perNode: NodeDownloadStatus[];\n  } {\n    return collectDownloadStatus(modelId, nodeIds);\n  }\n\n  // Helper to get download status for an instance\n  function getInstanceDownloadStatus(\n    instanceId: string,\n    instanceWrapped: unknown,\n  ): {\n    isDownloading: boolean;\n    isFailed: boolean;\n    errorMessage: string | null;\n    progress: DownloadProgress | null;\n    statusText: string;\n    perNode: NodeDownloadStatus[];\n  } {\n    // Unwrap the instance to get shard assignments\n    const [instanceTag, instance] = getTagged(instanceWrapped);\n    if (!instance || typeof instance !== \"object\") {\n      return {\n        isDownloading: false,\n        isFailed: false,\n        errorMessage: null,\n        progress: null,\n        statusText: \"PREPARING\",\n        perNode: [],\n      };\n    }\n\n    const inst = instance as {\n      shardAssignments?: {\n        nodeToRunner?: Record<string, string>;\n        runnerToShard?: Record<string, unknown>;\n        modelId?: string;\n      };\n    };\n    const instanceModelId = inst.shardAssignments?.modelId;\n\n    if (!instanceModelId) {\n      const statusInfo = deriveInstanceStatus(instanceWrapped);\n      return {\n        isDownloading: false,\n        isFailed: statusInfo.statusText === \"FAILED\",\n        errorMessage: null,\n        progress: null,\n        statusText: statusInfo.statusText,\n        perNode: [],\n      };\n    }\n\n    // Get node IDs assigned to this instance\n    const nodeToRunner = inst.shardAssignments?.nodeToRunner || {};\n    const runnerToShard = inst.shardAssignments?.runnerToShard || {};\n    const runnerToNode: Record<string, string> = {};\n    for (const [nodeId, runnerId] of Object.entries(nodeToRunner)) {\n      runnerToNode[runnerId] = nodeId;\n    }\n    const instanceNodeIds = Object.keys(runnerToShard)\n      .map((runnerId) => runnerToNode[runnerId])\n      .filter(Boolean);\n\n    const result = collectDownloadStatus(instanceModelId, instanceNodeIds);\n\n    if (result.failedError) {\n      return {\n        isDownloading: false,\n        isFailed: true,\n        errorMessage: result.failedError,\n        progress: null,\n        statusText: \"FAILED\",\n        perNode: [],\n      };\n    }\n\n    if (!result.isDownloading) {\n      const statusInfo = deriveInstanceStatus(instanceWrapped);\n      return {\n        isDownloading: false,\n        isFailed: statusInfo.statusText === \"FAILED\",\n        errorMessage: null,\n        progress: null,\n        statusText: statusInfo.statusText,\n        perNode: result.perNode,\n      };\n    }\n\n    return {\n      isDownloading: true,\n      isFailed: false,\n      errorMessage: null,\n      progress: result.progress,\n      statusText: \"DOWNLOADING\",\n      perNode: result.perNode,\n    };\n  }\n\n  // Derive instance status from runners\n  // Get color class for a status\n  function getStatusColor(statusText: string): string {\n    switch (statusText) {\n      case \"FAILED\":\n        return \"text-red-400\";\n      case \"SHUTDOWN\":\n        return \"text-gray-400\";\n      case \"DOWNLOADING\":\n        return \"text-blue-400\";\n      case \"LOADING\":\n      case \"WARMING UP\":\n   
   case \"WAITING\":\n      case \"INITIALIZING\":\n        return \"text-yellow-400\";\n      case \"RUNNING\":\n        return \"text-teal-400\";\n      case \"READY\":\n      case \"LOADED\":\n        return \"text-green-400\";\n      default:\n        return \"text-exo-light-gray\";\n    }\n  }\n\n  function deriveInstanceStatus(instanceWrapped: unknown): {\n    statusText: string;\n    statusClass: string;\n    layersLoaded?: number;\n    totalLayers?: number;\n  } {\n    const [instanceTag, instance] = getTagged(instanceWrapped);\n    if (!instance || typeof instance !== \"object\") {\n      return { statusText: \"PREPARING\", statusClass: \"inactive\" };\n    }\n\n    const inst = instance as {\n      shardAssignments?: { runnerToShard?: Record<string, unknown> };\n    };\n    const runnerIds = Object.keys(inst.shardAssignments?.runnerToShard || {});\n\n    const statuses = runnerIds\n      .map((rid) => {\n        const r = runnersData[rid];\n        if (!r) return null;\n        const [kind] = getTagged(r);\n        const statusMap: Record<string, string> = {\n          RunnerWaitingForInitialization: \"WaitingForInitialization\",\n          RunnerInitializingBackend: \"InitializingBackend\",\n          RunnerWaitingForModel: \"WaitingForModel\",\n          RunnerLoading: \"Loading\",\n          RunnerLoaded: \"Loaded\",\n          RunnerWarmingUp: \"WarmingUp\",\n          RunnerReady: \"Ready\",\n          RunnerRunning: \"Running\",\n          RunnerShutdown: \"Shutdown\",\n          RunnerFailed: \"Failed\",\n        };\n        return kind ? statusMap[kind] || null : null;\n      })\n      .filter((s): s is string => s !== null);\n\n    const has = (s: string) => statuses.includes(s);\n\n    if (statuses.length === 0)\n      return { statusText: \"PREPARING\", statusClass: \"inactive\" };\n    if (has(\"Failed\")) return { statusText: \"FAILED\", statusClass: \"failed\" };\n    if (has(\"Shutdown\"))\n      return { statusText: \"SHUTDOWN\", statusClass: \"inactive\" };\n    if (has(\"Loading\")) {\n      // Tensor parallel: each runner loads all layers — use max/min (bottleneck)\n      // Pipeline parallel: each runner loads a disjoint slice — use sum\n      const isTensor = instanceTag === \"MlxJacclInstance\";\n      let layersLoaded = isTensor ? Infinity : 0;\n      let totalLayers = 0;\n      for (const rid of runnerIds) {\n        const r = runnersData[rid];\n        if (!r) continue;\n        const [kind, payload] = getTagged(r);\n        if (\n          kind === \"RunnerLoading\" &&\n          payload &&\n          typeof payload === \"object\"\n        ) {\n          const p = payload as { layersLoaded?: number; totalLayers?: number };\n          if (isTensor) {\n            layersLoaded = Math.min(layersLoaded, p.layersLoaded ?? 0);\n            totalLayers = Math.max(totalLayers, p.totalLayers ?? 0);\n          } else {\n            layersLoaded += p.layersLoaded ?? 0;\n            totalLayers += p.totalLayers ?? 
0;\n          }\n        }\n      }\n      if (isTensor && layersLoaded === Infinity) layersLoaded = 0;\n      return {\n        statusText: \"LOADING\",\n        statusClass: \"starting\",\n        layersLoaded,\n        totalLayers,\n      };\n    }\n    if (has(\"WarmingUp\"))\n      return { statusText: \"WARMING UP\", statusClass: \"starting\" };\n    if (has(\"Running\"))\n      return { statusText: \"RUNNING\", statusClass: \"running\" };\n    if (has(\"Ready\")) return { statusText: \"READY\", statusClass: \"loaded\" };\n    if (has(\"Loaded\")) return { statusText: \"LOADED\", statusClass: \"loaded\" };\n    if (has(\"WaitingForModel\"))\n      return { statusText: \"WAITING\", statusClass: \"starting\" };\n    if (has(\"InitializingBackend\"))\n      return { statusText: \"INITIALIZING\", statusClass: \"starting\" };\n    if (has(\"WaitingForInitialization\"))\n      return { statusText: \"INITIALIZING\", statusClass: \"starting\" };\n\n    return { statusText: \"RUNNING\", statusClass: \"active\" };\n  }\n\n  function getBytes(value: unknown): number {\n    if (typeof value === \"number\") return value;\n    if (value && typeof value === \"object\") {\n      const v = value as Record<string, unknown>;\n      if (typeof v.inBytes === \"number\") return v.inBytes;\n    }\n    return 0;\n  }\n\n  async function deleteInstance(instanceId: string) {\n    if (!confirm(`Delete instance ${instanceId.slice(0, 8)}...?`)) return;\n\n    // Get the model ID of the instance being deleted before we delete it\n    const deletedInstanceModelId = getInstanceModelId(instanceData[instanceId]);\n    const wasSelected = selectedChatModel() === deletedInstanceModelId;\n\n    try {\n      const response = await fetch(`/instance/${instanceId}`, {\n        method: \"DELETE\",\n        headers: { \"Content-Type\": \"application/json\" },\n      });\n\n      if (!response.ok) {\n        console.error(\"Failed to delete instance:\", response.status);\n        addToast({ type: \"error\", message: \"Failed to delete instance\" });\n      } else if (wasSelected) {\n        // If we deleted the currently selected model, switch to another available model\n        // Find another instance that isn't the one we just deleted\n        const remainingInstances = Object.entries(instanceData).filter(\n          ([id]) => id !== instanceId,\n        );\n        if (remainingInstances.length > 0) {\n          // Select the last instance (most recently added, since objects preserve insertion order)\n          const [, lastInstance] =\n            remainingInstances[remainingInstances.length - 1];\n          const newModelId = getInstanceModelId(lastInstance);\n          if (\n            newModelId &&\n            newModelId !== \"Unknown\" &&\n            newModelId !== \"Unknown Model\"\n          ) {\n            setSelectedChatModel(newModelId);\n          } else {\n            // Clear selection if no valid model found\n            setSelectedChatModel(\"\");\n          }\n        } else {\n          // No more instances, clear the selection\n          setSelectedChatModel(\"\");\n        }\n      }\n    } catch (error) {\n      console.error(\"Error deleting instance:\", error);\n    }\n  }\n\n  // Helper to unwrap tagged unions like { MlxRingInstance: {...} }\n  function getTagged(obj: unknown): [string | null, unknown] {\n    if (!obj || typeof obj !== \"object\") return [null, null];\n    const keys = Object.keys(obj as Record<string, unknown>);\n    if (keys.length === 1) {\n      return [keys[0], (obj as 
Record<string, unknown>)[keys[0]]];\n    }\n    return [null, null];\n  }\n\n  // Get model ID from an instance\n  function getInstanceModelId(instanceWrapped: unknown): string {\n    const [, instance] = getTagged(instanceWrapped);\n    if (!instance || typeof instance !== \"object\") return \"Unknown\";\n    const inst = instance as { shardAssignments?: { modelId?: string } };\n    return inst.shardAssignments?.modelId || \"Unknown Model\";\n  }\n\n  // Get instance details: type (MLX Ring/IBV), sharding (Pipeline/Tensor), and node names\n  function getInstanceInfo(instanceWrapped: unknown): {\n    instanceType: string;\n    sharding: string;\n    nodeNames: string[];\n    nodeIds: string[];\n    nodeCount: number;\n  } {\n    const [instanceTag, instance] = getTagged(instanceWrapped);\n    if (!instance || typeof instance !== \"object\") {\n      return {\n        instanceType: \"Unknown\",\n        sharding: \"Unknown\",\n        nodeNames: [],\n        nodeIds: [],\n        nodeCount: 0,\n      };\n    }\n\n    // Instance type from tag\n    let instanceType = \"Unknown\";\n    if (instanceTag === \"MlxRingInstance\") instanceType = \"MLX Ring\";\n    else if (instanceTag === \"MlxJacclInstance\") instanceType = \"MLX RDMA\";\n\n    const inst = instance as {\n      shardAssignments?: {\n        nodeToRunner?: Record<string, string>;\n        runnerToShard?: Record<string, unknown>;\n      };\n    };\n\n    // Sharding strategy from first shard\n    let sharding = \"Unknown\";\n    const runnerToShard = inst.shardAssignments?.runnerToShard || {};\n    const firstShardWrapped = Object.values(runnerToShard)[0];\n    if (firstShardWrapped) {\n      const [shardTag] = getTagged(firstShardWrapped);\n      if (shardTag === \"PipelineShardMetadata\") sharding = \"Pipeline\";\n      else if (shardTag === \"TensorShardMetadata\") sharding = \"Tensor\";\n      else if (shardTag === \"PrefillDecodeShardMetadata\")\n        sharding = \"Prefill/Decode\";\n    }\n\n    // Node names from topology\n    const nodeToRunner = inst.shardAssignments?.nodeToRunner || {};\n    const nodeIds = Object.keys(nodeToRunner);\n    const nodeNames = nodeIds.map((nodeId) => {\n      const node = data?.nodes?.[nodeId];\n      return node?.friendly_name || nodeId.slice(0, 8);\n    });\n\n    return {\n      instanceType,\n      sharding,\n      nodeNames,\n      nodeIds,\n      nodeCount: nodeIds.length,\n    };\n  }\n\n  // Compute instance statuses by modelId for the model picker\n  const modelInstanceStatuses = $derived.by(() => {\n    const result: Record<string, { status: string; statusClass: string }> = {};\n    for (const [id, inst] of Object.entries(instanceData)) {\n      const modelId = getInstanceModelId(inst);\n      if (!modelId || modelId === \"Unknown\" || modelId === \"Unknown Model\")\n        continue;\n      const dlStatus = getInstanceDownloadStatus(id, inst);\n      const statusText = dlStatus.statusText;\n      let statusClass = \"inactive\";\n      if (\n        statusText === \"READY\" ||\n        statusText === \"RUNNING\" ||\n        statusText === \"LOADED\"\n      ) {\n        statusClass = \"ready\";\n      } else if (statusText === \"DOWNLOADING\") {\n        statusClass = \"downloading\";\n      } else if (statusText === \"LOADING\" || statusText === \"WARMING UP\") {\n        statusClass = \"loading\";\n      }\n      // Keep the best status per modelId (ready > loading > downloading > other)\n      const existing = result[modelId];\n      if (existing) {\n        const rank = (c: 
string) =>\n          c === \"ready\" ? 3 : c === \"loading\" ? 2 : c === \"downloading\" ? 1 : 0;\n        if (rank(statusClass) <= rank(existing.statusClass)) continue;\n      }\n      result[modelId] = { status: statusText, statusClass };\n    }\n    return result;\n  });\n\n  function formatLastUpdate(): string {\n    if (!update) return \"ACQUIRING...\";\n    const seconds = Math.floor((Date.now() - update) / 1000);\n    if (seconds < 5) return \"LIVE\";\n    return `${seconds}s AGO`;\n  }\n\n  function formatBytes(bytes: number, decimals = 2): string {\n    if (!bytes || bytes === 0) return \"0 B\";\n    const k = 1024;\n    const sizes = [\"B\", \"KB\", \"MB\", \"GB\", \"TB\"];\n    const i = Math.floor(Math.log(bytes) / Math.log(k));\n    return (\n      parseFloat((bytes / Math.pow(k, i)).toFixed(decimals)) + \" \" + sizes[i]\n    );\n  }\n\n  function formatSpeed(bps: number): string {\n    if (!bps || bps <= 0) return \"0 B/s\";\n    return formatBytes(bps, 1) + \"/s\";\n  }\n\n  function getNodeLabel(nodeId: string): string {\n    const node = data?.nodes?.[nodeId];\n    return node?.friendly_name || nodeId.slice(0, 8);\n  }\n\n  function getInterfaceLabel(\n    nodeId: string,\n    ip?: string,\n  ): { label: string; missing: boolean } {\n    if (!ip) return { label: \"?\", missing: true };\n    const node = data?.nodes?.[nodeId];\n    if (!node) return { label: \"?\", missing: true };\n\n    // Prefer explicit network_interfaces from NodePerformanceProfile\n    const matchFromInterfaces = node.network_interfaces?.find((iface) =>\n      (iface.addresses || []).some((addr) => addr === ip),\n    );\n    if (matchFromInterfaces?.name) {\n      return {\n        label: `${matchFromInterfaces.name} on ${getNodeLabel(nodeId)}`,\n        missing: false,\n      };\n    }\n\n    // Fallback to derived ip_to_interface map\n    const mapped = node.ip_to_interface?.[ip];\n    if (mapped && mapped.trim().length > 0) {\n      return { label: `${mapped} on ${getNodeLabel(nodeId)}`, missing: false };\n    }\n\n    return { label: \"?\", missing: true };\n  }\n\n  function getOrderedRunnerNodes(\n    instance: Record<string, unknown>,\n    shardType: \"Pipeline\" | \"Tensor\",\n  ) {\n    const runnerToShard =\n      (\n        instance.shardAssignments as\n          | { runnerToShard?: Record<string, unknown> }\n          | undefined\n      )?.runnerToShard || {};\n    const nodeToRunner =\n      (\n        instance.shardAssignments as\n          | { nodeToRunner?: Record<string, string> }\n          | undefined\n      )?.nodeToRunner || {};\n    const runnerEntries = Object.entries(runnerToShard).map(\n      ([runnerId, shardWrapped]) => {\n        const [tag, shard] = getTagged(shardWrapped);\n        const meta = shard as\n          | {\n              modelMeta?: {\n                worldSize?: number;\n                nLayers?: number;\n                deviceRank?: number;\n              };\n            }\n          | undefined;\n        const deviceRank = meta?.modelMeta?.deviceRank ?? 0;\n        return { runnerId, tag, deviceRank };\n      },\n    );\n\n    const ordered = runnerEntries\n      .filter((r) =>\n        shardType === \"Pipeline\"\n          ? 
r.tag === \"PipelineShardMetadata\"\n          : r.tag === \"TensorShardMetadata\",\n      )\n      .sort((a, b) => a.deviceRank - b.deviceRank)\n      .map((r, idx) => {\n        const nodeId = Object.entries(nodeToRunner).find(\n          ([, rid]) => rid === r.runnerId,\n        )?.[0];\n        return { nodeId, runnerId: r.runnerId, order: idx };\n      })\n      .filter((item) => item.nodeId);\n\n    return ordered as Array<{\n      nodeId: string;\n      runnerId: string;\n      order: number;\n    }>;\n  }\n\n  function pickHost(\n    hosts?: Array<{ ip: string; port: number }>,\n  ): { ip: string; port: number } | null {\n    if (!hosts || hosts.length === 0) return null;\n    const scored = hosts\n      .filter((h) => h.ip && h.ip !== \"0.0.0.0\" && h.port && h.port > 0)\n      .map((h) => {\n        const ip = h.ip;\n        const score =\n          ip.startsWith(\"10.\") ||\n          ip.startsWith(\"172.\") ||\n          ip.startsWith(\"192.168\")\n            ? 3\n            : ip.startsWith(\"169.254\")\n              ? 2\n              : 1;\n        return { host: h, score };\n      });\n    if (scored.length === 0) return null;\n    scored.sort((a, b) => b.score - a.score);\n    return scored[0].host;\n  }\n\n  function getInstanceConnections(instanceWrapped: unknown): Array<{\n    from: string;\n    to: string;\n    ip: string;\n    ifaceLabel: string;\n    missingIface: boolean;\n  }> {\n    const [instanceTag, instance] = getTagged(instanceWrapped);\n    if (!instance || typeof instance !== \"object\") return [];\n\n    // Jaccl (RDMA) – show RDMA interfaces from ibvDevices\n    if (instanceTag === \"MlxJacclInstance\") {\n      const ordered = getOrderedRunnerNodes(\n        instance as Record<string, unknown>,\n        \"Tensor\",\n      );\n      const ibvDevices =\n        (instance as { ibvDevices?: Array<Array<string | null>> }).ibvDevices ||\n        [];\n      const rows: Array<{\n        from: string;\n        to: string;\n        ip: string;\n        ifaceLabel: string;\n        missingIface: boolean;\n      }> = [];\n\n      for (let i = 0; i < ordered.length; i++) {\n        for (let j = i + 1; j < ordered.length; j++) {\n          const iface = ibvDevices[i]?.[j] ?? ibvDevices[j]?.[i] ?? 
null;\n          if (!iface) continue;\n          const fromId = ordered[i].nodeId;\n          const toId = ordered[j].nodeId;\n          rows.push({\n            from: getNodeLabel(fromId),\n            to: getNodeLabel(toId),\n            ip: iface,\n            ifaceLabel: `RDMA ${iface}`,\n            missingIface: false,\n          });\n        }\n      }\n      return rows;\n    }\n\n    // Ring – derive ring order from pipeline shard ranks and pick host IPs from hostsByNode\n    if (instanceTag === \"MlxRingInstance\") {\n      const ordered = getOrderedRunnerNodes(\n        instance as Record<string, unknown>,\n        \"Pipeline\",\n      );\n      const hostsByNode =\n        (\n          instance as {\n            hostsByNode?: Record<string, Array<{ ip: string; port: number }>>;\n          }\n        ).hostsByNode || {};\n      const rows: Array<{\n        from: string;\n        to: string;\n        ip: string;\n        ifaceLabel: string;\n        missingIface: boolean;\n      }> = [];\n      if (ordered.length === 0) return rows;\n\n      for (let idx = 0; idx < ordered.length; idx++) {\n        const current = ordered[idx];\n        const next = ordered[(idx + 1) % ordered.length];\n        const host = pickHost(hostsByNode[next.nodeId]);\n        const ip = host ? `${host.ip}:${host.port}` : \"?\";\n        const ifaceInfo = host\n          ? getInterfaceLabel(next.nodeId, host.ip)\n          : { label: \"?\", missing: true };\n        rows.push({\n          from: getNodeLabel(current.nodeId),\n          to: getNodeLabel(next.nodeId),\n          ip,\n          ifaceLabel: ifaceInfo.label,\n          missingIface: ifaceInfo.missing,\n        });\n      }\n      return rows;\n    }\n\n    return [];\n  }\n\n  function formatEta(ms: number): string {\n    if (!ms || ms <= 0) return \"--\";\n    const totalSeconds = Math.round(ms / 1000);\n    const s = totalSeconds % 60;\n    const m = Math.floor(totalSeconds / 60) % 60;\n    const h = Math.floor(totalSeconds / 3600);\n    if (h > 0) return `${h}h ${m}m`;\n    if (m > 0) return `${m}m ${s}s`;\n    return `${s}s`;\n  }\n\n  function handleNewChat() {\n    chatLaunchState = \"idle\";\n    pendingChatModelId = null;\n    selectedChatCategory = null;\n    pendingAutoMessage = null;\n    userForcedIdle = true;\n    setSelectedChatModel(\"\");\n    createConversation();\n  }\n\n  function handleGoHome() {\n    chatLaunchState = \"idle\";\n    pendingChatModelId = null;\n    selectedChatCategory = null;\n    pendingAutoMessage = null;\n    userForcedIdle = true;\n    // Restore chat model from the sidebar preview selection so both selectors stay in sync\n    setSelectedChatModel(selectedModelId ?? 
\"\");\n    clearChat();\n  }\n\n  // Slider drag handlers\n  function handleSliderDrag(clientX: number) {\n    if (!sliderTrackElement || availableMinNodes <= 1) return;\n\n    const rect = sliderTrackElement.getBoundingClientRect();\n    const percentage = Math.max(\n      0,\n      Math.min(1, (clientX - rect.left) / rect.width),\n    );\n    const rawValue = Math.round(percentage * (availableMinNodes - 1)) + 1;\n    const clampedValue = Math.max(1, Math.min(availableMinNodes, rawValue));\n\n    // Find nearest valid value\n    const validCounts = validMinNodeCounts();\n    if (validCounts.has(clampedValue)) {\n      selectedMinNodes = clampedValue;\n    } else {\n      // Find nearest valid value\n      let nearest = clampedValue;\n      let minDist = Infinity;\n      for (const v of validCounts) {\n        const dist = Math.abs(v - clampedValue);\n        if (dist < minDist) {\n          minDist = dist;\n          nearest = v;\n        }\n      }\n      if (validCounts.size > 0) {\n        selectedMinNodes = nearest;\n      }\n    }\n  }\n\n  function handleSliderMouseDown(event: MouseEvent) {\n    isDraggingSlider = true;\n    handleSliderDrag(event.clientX);\n  }\n\n  function handleSliderMouseMove(event: MouseEvent) {\n    if (isDraggingSlider) {\n      handleSliderDrag(event.clientX);\n    }\n  }\n\n  function handleSliderMouseUp() {\n    isDraggingSlider = false;\n    saveLaunchDefaults();\n  }\n\n  // Handle touch events for mobile\n  function handleSliderTouchStart(event: TouchEvent) {\n    isDraggingSlider = true;\n    if (event.touches.length > 0) {\n      handleSliderDrag(event.touches[0].clientX);\n    }\n  }\n\n  function handleSliderTouchMove(event: TouchEvent) {\n    if (isDraggingSlider && event.touches.length > 0) {\n      event.preventDefault();\n      handleSliderDrag(event.touches[0].clientX);\n    }\n  }\n\n  function handleSliderTouchEnd() {\n    isDraggingSlider = false;\n    saveLaunchDefaults();\n  }\n\n  const nodeCount = $derived(data ? Object.keys(data.nodes).length : 0);\n  const instanceCount = $derived(Object.keys(instanceData).length);\n\n  // ── Instance status transition toasts ──\n  // Track previous statuses so we can detect meaningful transitions and fire toasts.\n  let previousInstanceStatuses: Record<string, string> = {};\n\n  $effect(() => {\n    const currentStatuses: Record<string, string> = {};\n    for (const [id, inst] of Object.entries(instanceData)) {\n      const dlStatus = getInstanceDownloadStatus(id, inst);\n      currentStatuses[id] = dlStatus.statusText;\n    }\n\n    const prev = previousInstanceStatuses;\n\n    // Only fire toasts if we had a previous snapshot (skip the very first poll)\n    if (Object.keys(prev).length > 0) {\n      for (const [id, currentStatus] of Object.entries(currentStatuses)) {\n        const prevStatus = prev[id];\n        if (!prevStatus || prevStatus === currentStatus) continue;\n\n        const modelId = getInstanceModelId(instanceData[id]);\n        const shortName = modelId\n          ? (modelId.split(\"/\").pop() ?? 
modelId)\n          : id.slice(0, 8);\n\n        // Downloading -> non-downloading, non-failure = download complete\n        if (\n          prevStatus === \"DOWNLOADING\" &&\n          currentStatus !== \"DOWNLOADING\" &&\n          currentStatus !== \"FAILED\"\n        ) {\n          addToast({\n            type: \"success\",\n            message: `Download complete: ${shortName}`,\n          });\n        }\n\n        // Loading/Warming Up -> Ready/Loaded/Running = model ready\n        if (\n          (prevStatus === \"LOADING\" || prevStatus === \"WARMING UP\") &&\n          (currentStatus === \"READY\" ||\n            currentStatus === \"LOADED\" ||\n            currentStatus === \"RUNNING\")\n        ) {\n          addToast({ type: \"success\", message: `Model ready: ${shortName}` });\n        }\n\n        // Any -> Failed\n        if (prevStatus !== \"FAILED\" && currentStatus === \"FAILED\") {\n          addToast({ type: \"error\", message: `Model failed: ${shortName}` });\n        }\n\n        // Any -> Shutdown\n        if (prevStatus !== \"SHUTDOWN\" && currentStatus === \"SHUTDOWN\") {\n          addToast({ type: \"info\", message: `Model shut down: ${shortName}` });\n        }\n      }\n    }\n\n    previousInstanceStatuses = currentStatuses;\n  });\n\n  // ── Connection status toasts ──\n  let previousConnectionStatus: boolean | null = null;\n\n  $effect(() => {\n    const connected = isConnected();\n    if (previousConnectionStatus !== null) {\n      if (previousConnectionStatus && !connected) {\n        addToast({\n          type: \"warning\",\n          message: \"Connection to server lost\",\n          persistent: true,\n        });\n      } else if (!previousConnectionStatus && connected) {\n        dismissByMessage(\"Connection to server lost\");\n        addToast({ type: \"success\", message: \"Connection restored\" });\n      }\n    }\n    previousConnectionStatus = connected;\n  });\n\n  const suggestedPrompts = [\n    \"Write a poem about the ocean\",\n    \"Explain quantum computing simply\",\n    \"Help me debug my code\",\n    \"Tell me a creative story\",\n  ];\n\n  // ── Seamless chat: launch models from chat view ──\n  type ChatLaunchState =\n    | \"idle\"\n    | \"launching\"\n    | \"downloading\"\n    | \"loading\"\n    | \"ready\";\n  let chatLaunchState = $state<ChatLaunchState>(\"idle\");\n  let pendingChatModelId = $state<string | null>(null);\n  let selectedChatCategory = $state<string | null>(null);\n  // Guard: when true, the restore $effect must not override chatLaunchState.\n  // Set by handleNewChat/handleGoHome; cleared when the user picks a model.\n  let userForcedIdle = $state(false);\n\n  // Restore chat launch state when switching conversations\n  $effect(() => {\n    const currentModel = selectedChatModel();\n    // When the user explicitly requested the model selector (New Chat / Go Home),\n    // skip restoring state so the selector stays visible.\n    if (userForcedIdle) return;\n    if (!currentModel) {\n      if (chatStarted && chatLaunchState !== \"idle\") {\n        chatLaunchState = \"idle\";\n        pendingChatModelId = null;\n        selectedChatCategory = null;\n      }\n      return;\n    }\n\n    // Model is already running — no progress to show\n    if (hasRunningInstance(currentModel)) {\n      if (chatLaunchState !== \"ready\") {\n        chatLaunchState = \"ready\";\n      }\n      pendingChatModelId = currentModel;\n      return;\n    }\n\n    // Model is downloading\n    const dlStatus = 
getModelDownloadStatus(currentModel);\n    if (dlStatus.isDownloading) {\n      chatLaunchState = \"downloading\";\n      pendingChatModelId = currentModel;\n      return;\n    }\n\n    // Model is loading or in another pre-ready state\n    for (const [, inst] of Object.entries(instanceData)) {\n      if (getInstanceModelId(inst) !== currentModel) continue;\n      const status = deriveInstanceStatus(inst);\n      if (status.statusText === \"LOADING\") {\n        chatLaunchState = \"loading\";\n        pendingChatModelId = currentModel;\n        return;\n      }\n      if (\n        status.statusText === \"WARMING UP\" ||\n        status.statusText === \"WAITING\" ||\n        status.statusText === \"INITIALIZING\" ||\n        status.statusText === \"PREPARING\"\n      ) {\n        chatLaunchState = \"launching\";\n        pendingChatModelId = currentModel;\n        return;\n      }\n    }\n\n    // Fallthrough: model exists but has no active instance/download/loading state\n    chatLaunchState = \"idle\";\n    pendingChatModelId = null;\n    selectedChatCategory = null;\n  });\n\n  // Suggested prompts per category\n  const categorySuggestedPrompts: Record<string, string[]> = {\n    coding: [\n      \"Write a Snake game in Python\",\n      \"Build a REST API with FastAPI\",\n      \"Explain how async/await works\",\n      \"Help me write unit tests for my code\",\n    ],\n    writing: [\n      \"Write a short story about time travel\",\n      \"Draft a professional email to a client\",\n      \"Create a haiku about the ocean\",\n      \"Summarize the key ideas of stoicism\",\n    ],\n    agentic: [\n      \"Plan a weekend trip to Tokyo\",\n      \"Research and compare React vs Svelte\",\n      \"Create a step-by-step guide to learn ML\",\n      \"Analyze the pros and cons of remote work\",\n    ],\n    biggest: [\n      \"Explain quantum computing simply\",\n      \"Help me brainstorm startup ideas\",\n      \"What are the key differences between TCP and UDP?\",\n      \"Write a Python script to analyze a CSV file\",\n    ],\n    auto: [\n      \"Explain quantum computing simply\",\n      \"Help me brainstorm ideas for a side project\",\n      \"Write a Python function to sort a list\",\n      \"What makes a great technical interview?\",\n    ],\n  };\n\n  // Cluster label for ChatModelSelector header\n  const chatClusterLabel = $derived.by(() => {\n    if (!data) return \"your Mac\";\n    const nodes = Object.values(data.nodes);\n    if (nodes.length === 0) return \"your Mac\";\n    if (nodes.length === 1) {\n      const node = nodes[0];\n      const name = node.system_info?.model_id || \"your Mac\";\n      const totalMem =\n        node.macmon_info?.memory?.ram_total ?? node.system_info?.memory ?? 
0;\n      const memGB = Math.round(totalMem / (1024 * 1024 * 1024));\n      return `${name} ${memGB}GB`;\n    }\n    const totalMemGB = Math.round(clusterTotalMemoryGB());\n    return `cluster ${totalMemGB}GB`;\n  });\n\n  // Check if a model already has a running instance\n  function hasRunningInstance(modelId: string): boolean {\n    for (const [, inst] of Object.entries(instanceData)) {\n      const id = getInstanceModelId(inst);\n      if (id === modelId) {\n        const status = deriveInstanceStatus(inst);\n        if (\n          status.statusText === \"READY\" ||\n          status.statusText === \"LOADED\" ||\n          status.statusText === \"RUNNING\"\n        ) {\n          return true;\n        }\n      }\n    }\n    return false;\n  }\n\n  function hasExistingInstance(modelId: string): boolean {\n    for (const [, inst] of Object.entries(instanceData)) {\n      if (getInstanceModelId(inst) === modelId) return true;\n    }\n    return false;\n  }\n\n  // Pick optimal placement from previews (frontend logic)\n  // Rules: 1-node → Pipeline/Ring, multi-node with RDMA → Tensor/Jaccl (most nodes),\n  //         multi-node without RDMA → 1-node Pipeline/Ring\n  function pickOptimalPlacement(\n    previews: PlacementPreview[],\n  ): PlacementPreview | null {\n    const valid = previews.filter((p) => p.instance && !p.error);\n\n    // Check if any valid placement uses multiple nodes (indicates multi-node cluster)\n    const hasMultiNode = valid.some((p) => getPreviewNodeCount(p) > 1);\n\n    if (hasMultiNode) {\n      // Multi-node with RDMA: prefer Jaccl + Tensor with most nodes (fastest TPS)\n      const jacclTensor = valid\n        .filter(\n          (p) => p.instance_meta === \"MlxJaccl\" && p.sharding === \"Tensor\",\n        )\n        .sort((a, b) => getPreviewNodeCount(b) - getPreviewNodeCount(a));\n      if (jacclTensor.length > 0) return jacclTensor[0];\n\n      // Multi-node without RDMA: fall back to single-node Pipeline/Ring\n      const singlePipeline = valid.filter(\n        (p) =>\n          p.instance_meta === \"MlxRing\" &&\n          p.sharding === \"Pipeline\" &&\n          getPreviewNodeCount(p) === 1,\n      );\n      if (singlePipeline.length > 0) return singlePipeline[0];\n    }\n\n    // Single node (or final fallback): Pipeline/Ring with fewest nodes\n    const ringPipeline = valid\n      .filter((p) => p.instance_meta === \"MlxRing\" && p.sharding === \"Pipeline\")\n      .sort((a, b) => getPreviewNodeCount(a) - getPreviewNodeCount(b));\n    if (ringPipeline.length > 0) return ringPipeline[0];\n\n    // Last resort: any valid placement, fewest nodes\n    return (\n      valid.sort(\n        (a, b) => getPreviewNodeCount(a) - getPreviewNodeCount(b),\n      )[0] ?? 
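// [0] yields undefined on an empty list; normalize to null\n      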
null\n    );\n  }\n\n  // Launch a model for seamless chat\n  async function launchModelForChat(\n    modelId: string,\n    category: string,\n    skipCreate = false,\n  ) {\n    userForcedIdle = false;\n    pendingChatModelId = modelId;\n    selectedChatCategory = category;\n\n    // Check if already running — skip straight to chat\n    if (hasRunningInstance(modelId)) {\n      setSelectedChatModel(modelId);\n      if (!skipCreate) createConversation();\n      chatLaunchState = \"ready\";\n      return;\n    }\n\n    // Already has an instance (downloading/loading) — attach to its progress\n    if (hasExistingInstance(modelId)) {\n      setSelectedChatModel(modelId);\n      pendingChatModelId = modelId;\n      if (!skipCreate) createConversation();\n      const dlStatus = getModelDownloadStatus(modelId);\n      if (dlStatus.isDownloading) {\n        chatLaunchState = \"downloading\";\n      } else {\n        chatLaunchState = \"launching\";\n      }\n      return;\n    }\n\n    chatLaunchState = \"launching\";\n\n    try {\n      // Fetch placement previews\n      const res = await fetch(\n        `/instance/previews?model_id=${encodeURIComponent(modelId)}`,\n      );\n      if (!res.ok) {\n        addToast({\n          type: \"error\",\n          message: `Failed to get placements: ${await res.text()}`,\n        });\n        chatLaunchState = \"idle\";\n        return;\n      }\n      const data: { previews: PlacementPreview[] } = await res.json();\n      const placement = pickOptimalPlacement(data.previews);\n      if (!placement) {\n        addToast({\n          type: \"error\",\n          message: \"No valid placement found for this model\",\n        });\n        chatLaunchState = \"idle\";\n        return;\n      }\n\n      // Launch the instance\n      const launchRes = await fetch(\"/instance\", {\n        method: \"POST\",\n        headers: { \"Content-Type\": \"application/json\" },\n        body: JSON.stringify({ instance: placement.instance }),\n      });\n      if (!launchRes.ok) {\n        addToast({\n          type: \"error\",\n          message: `Failed to launch: ${await launchRes.text()}`,\n        });\n        chatLaunchState = \"idle\";\n        return;\n      }\n\n      setSelectedChatModel(modelId);\n      recordRecentLaunch(modelId);\n      if (!skipCreate) createConversation();\n      chatLaunchState = \"downloading\";\n    } catch (error) {\n      addToast({ type: \"error\", message: `Network error: ${error}` });\n      chatLaunchState = \"idle\";\n    }\n  }\n\n  // Handle auto-send: user typed without selecting a model\n  async function handleAutoSend(\n    content: string,\n    files?: {\n      id: string;\n      name: string;\n      type: string;\n      textContent?: string;\n      preview?: string;\n    }[],\n  ) {\n    // Clear forced-idle so restore effect resumes normal operation\n    userForcedIdle = false;\n\n    // Find the best already-running model by tier\n    let bestRunning: { id: string; tierIndex: number } | null = null;\n    for (const [, inst] of Object.entries(instanceData)) {\n      const modelId = getInstanceModelId(inst);\n      if (modelId === \"Unknown\" || modelId === \"Unknown Model\") continue;\n      if (!hasRunningInstance(modelId)) continue;\n      const info = models.find((m) => m.id === modelId);\n      if (!info) continue;\n      const tierIndex = getAutoTierIndex(info.base_model ?? 
\"\");\n      if (!bestRunning || tierIndex < bestRunning.tierIndex) {\n        bestRunning = { id: modelId, tierIndex };\n      }\n    }\n\n    // Find the best auto model that fits in available memory\n    const totalMem = availableMemoryGB();\n    const modelInfos = models.map((m) => ({\n      id: m.id,\n      name: m.name ?? \"\",\n      base_model: m.base_model ?? \"\",\n      storage_size_megabytes: m.storage_size_megabytes ?? 0,\n      capabilities: m.capabilities ?? [],\n      family: m.family ?? \"\",\n      quantization: m.quantization ?? \"\",\n    }));\n    const autoModel = pickAutoModel(modelInfos, totalMem);\n\n    // Prefer running model unless auto-pick is a strictly better tier\n    if (bestRunning) {\n      const autoTier = autoModel\n        ? getAutoTierIndex(autoModel.base_model)\n        : Infinity;\n      if (autoTier >= bestRunning.tierIndex) {\n        // Running model is same or better tier — use it directly\n        setSelectedChatModel(bestRunning.id);\n        if (!chatStarted) createConversation();\n        routeMessage(content, files);\n        return;\n      }\n    }\n\n    if (!autoModel) {\n      addToast({\n        type: \"error\",\n        message: \"No model fits in your available memory\",\n      });\n      return;\n    }\n\n    // Check if the chosen auto model is already running\n    if (hasRunningInstance(autoModel.id)) {\n      setSelectedChatModel(autoModel.id);\n      if (!chatStarted) createConversation();\n      routeMessage(content, files);\n      return;\n    }\n\n    // Already has an instance (downloading/loading) — attach to its progress\n    if (hasExistingInstance(autoModel.id)) {\n      selectedChatCategory = \"auto\";\n      setSelectedChatModel(autoModel.id);\n      pendingChatModelId = autoModel.id;\n      if (!chatStarted) createConversation();\n      pendingAutoMessage = { content, files };\n      const dlStatus = getModelDownloadStatus(autoModel.id);\n      if (dlStatus.isDownloading) {\n        chatLaunchState = \"downloading\";\n      } else {\n        chatLaunchState = \"launching\";\n      }\n      return;\n    }\n\n    // Need to launch first, then send\n    selectedChatCategory = \"auto\";\n    pendingChatModelId = autoModel.id;\n    chatLaunchState = \"launching\";\n\n    try {\n      const res = await fetch(\n        `/instance/previews?model_id=${encodeURIComponent(autoModel.id)}`,\n      );\n      if (!res.ok) {\n        addToast({\n          type: \"error\",\n          message: `Failed to get placements: ${await res.text()}`,\n        });\n        chatLaunchState = \"idle\";\n        return;\n      }\n      const data: { previews: PlacementPreview[] } = await res.json();\n      const placement = pickOptimalPlacement(data.previews);\n      if (!placement) {\n        addToast({ type: \"error\", message: \"No valid placement found\" });\n        chatLaunchState = \"idle\";\n        return;\n      }\n\n      const launchRes = await fetch(\"/instance\", {\n        method: \"POST\",\n        headers: { \"Content-Type\": \"application/json\" },\n        body: JSON.stringify({ instance: placement.instance }),\n      });\n      if (!launchRes.ok) {\n        addToast({\n          type: \"error\",\n          message: `Failed to launch: ${await launchRes.text()}`,\n        });\n        chatLaunchState = \"idle\";\n        return;\n      }\n\n      setSelectedChatModel(autoModel.id);\n      recordRecentLaunch(autoModel.id);\n      if (!chatStarted) createConversation();\n      chatLaunchState = \"downloading\";\n\n      // Queue 
the message to send once model is ready\n      pendingAutoMessage = { content, files };\n    } catch (error) {\n      addToast({ type: \"error\", message: `Network error: ${error}` });\n      chatLaunchState = \"idle\";\n    }\n  }\n\n  // Pending message to send after auto-launch completes\n  let pendingAutoMessage = $state<{\n    content: string;\n    files?: {\n      id: string;\n      name: string;\n      type: string;\n      textContent?: string;\n      preview?: string;\n    }[];\n  } | null>(null);\n\n  // Best running model by tier (for auto-pick display)\n  const bestRunningModelId = $derived.by(() => {\n    let best: { id: string; tierIndex: number } | null = null;\n    for (const [, inst] of Object.entries(instanceData)) {\n      const modelId = getInstanceModelId(inst);\n      if (modelId === \"Unknown\" || modelId === \"Unknown Model\") continue;\n      if (!hasRunningInstance(modelId)) continue;\n      const info = models.find((m) => m.id === modelId);\n      if (!info) continue;\n      const tierIndex = getAutoTierIndex(info.base_model ?? \"\");\n      if (!best || tierIndex < best.tierIndex) {\n        best = { id: modelId, tierIndex };\n      }\n    }\n    return best?.id ?? null;\n  });\n\n  // Track chat launch progress (download + loading)\n  const chatLaunchDownload = $derived.by(() => {\n    if (\n      !pendingChatModelId ||\n      (chatLaunchState !== \"downloading\" && chatLaunchState !== \"launching\")\n    )\n      return null;\n    const status = getModelDownloadStatus(pendingChatModelId);\n    if (status.isDownloading) return status.progress;\n    return null;\n  });\n\n  const chatLaunchLoadProgress = $derived.by(() => {\n    if (\n      !pendingChatModelId ||\n      chatLaunchState === \"idle\" ||\n      chatLaunchState === \"ready\"\n    )\n      return null;\n    let layersLoaded = 0,\n      totalLayers = 0;\n    for (const [, inst] of Object.entries(instanceData)) {\n      if (getInstanceModelId(inst) !== pendingChatModelId) continue;\n      const status = deriveInstanceStatus(inst);\n      if (\n        status.statusText === \"LOADING\" &&\n        status.totalLayers &&\n        status.totalLayers > 0\n      ) {\n        layersLoaded += status.layersLoaded ?? 
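// layersLoaded is optional on the status type; treat missing counts as zero\n          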
0;\n        totalLayers += status.totalLayers;\n      }\n    }\n    if (totalLayers === 0) return null;\n    return {\n      layersLoaded,\n      totalLayers,\n      percentage: (layersLoaded / totalLayers) * 100,\n    };\n  });\n\n  // Auto-advance chat launch state based on instance status\n  $effect(() => {\n    if (!pendingChatModelId || chatLaunchState === \"idle\") return;\n\n    // Check if model is now ready\n    if (hasRunningInstance(pendingChatModelId)) {\n      chatLaunchState = \"ready\";\n      // Send pending auto message if any\n      if (pendingAutoMessage) {\n        const msg = pendingAutoMessage;\n        pendingAutoMessage = null;\n        routeMessage(msg.content, msg.files);\n      }\n      return;\n    }\n\n    // If already ready (set by restore effect), don't downgrade state\n    if (chatLaunchState === \"ready\") return;\n\n    // Check if currently loading\n    if (chatLaunchLoadProgress) {\n      chatLaunchState = \"loading\";\n      return;\n    }\n\n    // Check if currently downloading\n    if (chatLaunchDownload) {\n      chatLaunchState = \"downloading\";\n    }\n  });\n\n  // Check if any instance is running (for showing model selector vs chat)\n  const hasAnyRunningInstance = $derived(() => {\n    for (const [, inst] of Object.entries(instanceData)) {\n      const status = deriveInstanceStatus(inst);\n      if (\n        status.statusText === \"READY\" ||\n        status.statusText === \"LOADED\" ||\n        status.statusText === \"RUNNING\"\n      ) {\n        return true;\n      }\n    }\n    return false;\n  });\n\n  // Handle model selection from ChatModelSelector\n  function handleChatModelSelect(modelId: string, category: string) {\n    launchModelForChat(modelId, category);\n  }\n\n  // Handle \"+ Add Model\" from ChatModelSelector\n  function handleChatAddModel() {\n    modelPickerContext = \"chat\";\n    isModelPickerOpen = true;\n  }\n\n  // Track which context opened the model picker (dashboard launch vs chat selection)\n  let modelPickerContext = $state<\"dashboard\" | \"chat\">(\"dashboard\");\n\n  // Open the model picker from a chat context (e.g. 
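clicking the model button in ChatForm).\n  // While modelPickerContext is \"chat\", handleChatPickerSelect consumes the\n  // picker's selection instead of the dashboard launch flow.\n  // (Entry point: 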
clicking the model button in ChatForm)\n  function openChatModelPicker() {\n    modelPickerContext = \"chat\";\n    isModelPickerOpen = true;\n  }\n\n  // Handle model selection from the picker when opened from chat context\n  function handleChatPickerSelect(modelId: string) {\n    setSelectedChatModel(modelId);\n    selectPreviewModel(modelId);\n    userForcedIdle = false;\n    isModelPickerOpen = false;\n  }\n\n  // Unified send handler: sends if model running, auto-launches if not\n  function handleChatSend(\n    content: string,\n    files?: {\n      id: string;\n      name: string;\n      type: string;\n      textContent?: string;\n      preview?: string;\n    }[],\n  ) {\n    const model = selectedChatModel();\n\n    // Model is selected and running — send directly\n    if (model && hasRunningInstance(model)) {\n      chatLaunchState = \"ready\";\n      routeMessage(content, files);\n      return;\n    }\n\n    // Model is selected but NOT running — launch it, queue the message\n    if (model) {\n      pendingAutoMessage = { content, files };\n      userForcedIdle = false;\n      launchModelForChat(model, \"picker\", messages().length > 0);\n      return;\n    }\n\n    // No model selected — fall through to auto-pick\n    handleAutoSend(content, files);\n  }\n\n  // Helper to get the number of nodes in a placement preview\n  function getPreviewNodeCount(preview: PlacementPreview): number {\n    if (!preview.memory_delta_by_node) return 0;\n    // Count nodes that have non-zero memory delta (i.e. nodes actually used)\n    return Object.entries(preview.memory_delta_by_node).filter(\n      ([_, delta]) => delta > 0,\n    ).length;\n  }\n\n  // Available min nodes options based on topology (like old dashboard)\n  const availableMinNodes = $derived(Math.max(1, nodeCount));\n\n  // Compute which min node values have valid previews for the current model/sharding/instance type\n  // A minNodes value N is valid if there exists a placement with nodeCount >= N\n  // Note: previewsData already contains previews for the selected model (fetched via API)\n  const validMinNodeCounts = $derived(() => {\n    if (!selectedModelId || previewsData.length === 0) {\n      // If no model selected or no previews, allow all node counts (UI shows all as clickable)\n      return new Set(\n        Array.from({ length: availableMinNodes }, (_, i) => i + 1),\n      );\n    }\n\n    // Find the max node count among valid placements for this model/sharding/instance\n    // (model_id filter not needed since previewsData is already for selected model)\n    let maxValidNodes = 0;\n    for (const preview of previewsData) {\n      if (preview.sharding !== selectedSharding) continue;\n      if (!matchesSelectedRuntime(preview.instance_meta)) continue;\n      if (preview.error !== null) continue;\n      if (!preview.memory_delta_by_node) continue;\n\n      const previewNodes = getPreviewNodeCount(preview);\n      if (previewNodes > maxValidNodes) {\n        maxValidNodes = previewNodes;\n      }\n    }\n\n    // All values from 1 to maxValidNodes are valid (since there's a placement with >= that many nodes)\n    if (maxValidNodes === 0) return new Set<number>();\n    return new Set(Array.from({ length: maxValidNodes }, (_, i) => i + 1));\n  });\n\n  // Get ALL filtered previews based on current settings (matching minimum nodes)\n  // Note: previewsData already contains previews for the selected model (fetched via API)\n  // Backend handles node_ids filtering, we filter by sharding/instance type and min nodes\n  const 
filteredPreviews = $derived(() => {\n    if (!selectedModelId || previewsData.length === 0) return [];\n\n    // Find previews matching sharding/instance type (model_id filter not needed since previewsData is already for selected model)\n    const matchingPreviews = previewsData.filter(\n      (p: PlacementPreview) =>\n        p.sharding === selectedSharding &&\n        matchesSelectedRuntime(p.instance_meta) &&\n        p.error === null &&\n        p.memory_delta_by_node !== null,\n    );\n\n    // Filter to previews with node count >= selectedMinNodes, sorted by node count (ascending)\n    return matchingPreviews\n      .filter(\n        (p: PlacementPreview) => getPreviewNodeCount(p) >= selectedMinNodes,\n      )\n      .sort(\n        (a: PlacementPreview, b: PlacementPreview) =>\n          getPreviewNodeCount(a) - getPreviewNodeCount(b),\n      );\n  });\n\n  // Get the first filtered preview (for launch function compatibility)\n  const filteredPreview = $derived(() => filteredPreviews()[0] ?? null);\n\n  // Auto-update selectedMinNodes when node count changes (default to 1 = show all placements)\n  $effect(() => {\n    const maxNodes = availableMinNodes;\n    if (!minNodesInitialized && maxNodes > 0) {\n      // On initial load, default to 1 (minimum) to show all valid placements\n      selectedMinNodes = 1;\n      minNodesInitialized = true;\n    } else if (selectedMinNodes > maxNodes) {\n      // If current selection exceeds available nodes, cap it\n      selectedMinNodes = maxNodes;\n    }\n  });\n\n  // Auto-adjust selectedMinNodes to a valid value when it becomes invalid\n  $effect(() => {\n    const valid = validMinNodeCounts();\n    if (valid.size > 0 && !valid.has(selectedMinNodes)) {\n      // Find the smallest valid count >= current selection, or the largest valid count\n      const validArray = Array.from(valid).sort((a, b) => a - b);\n      const nextValid =\n        validArray.find((n) => n >= selectedMinNodes) ??\n        validArray[validArray.length - 1];\n      if (nextValid !== undefined) {\n        selectedMinNodes = nextValid;\n      }\n    }\n  });\n\n  // Calculate total memory usage across all nodes\n  const clusterMemory = $derived(() => {\n    if (!data) return { used: 0, total: 0 };\n    return Object.values(data.nodes).reduce(\n      (acc, n) => {\n        const total =\n          n.macmon_info?.memory?.ram_total ?? n.system_info?.memory ?? 0;\n        const used = n.macmon_info?.memory?.ram_usage ?? 
0;\n        return { used: acc.used + used, total: acc.total + total };\n      },\n      { used: 0, total: 0 },\n    );\n  });\n</script>\n\n{#snippet clusterWarnings()}\n  {#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed) || (macStudioEn2RdmaWarning && !macStudioEn2Dismissed)}\n    <div class=\"absolute top-4 left-4 flex flex-col gap-2 z-40\">\n      {#if tbBridgeCycles.length > 0}\n        {@const cycle = tbBridgeCycles[0]}\n        {@const serviceName = getTbBridgeServiceName(cycle)}\n        {@const disableCmd = `sudo networksetup -setnetworkserviceenabled \"${serviceName}\" off`}\n        <div class=\"group relative\" role=\"alert\">\n          <div\n            class=\"flex items-center gap-2 px-3 py-2 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm cursor-help\"\n          >\n            <svg\n              class=\"w-5 h-5 text-yellow-400 flex-shrink-0\"\n              fill=\"none\"\n              viewBox=\"0 0 24 24\"\n              stroke=\"currentColor\"\n              stroke-width=\"2\"\n            >\n              <path\n                stroke-linecap=\"round\"\n                stroke-linejoin=\"round\"\n                d={warningIconPath}\n              />\n            </svg>\n            <span class=\"text-sm font-mono text-yellow-200\">\n              THUNDERBOLT BRIDGE CYCLE DETECTED\n            </span>\n          </div>\n\n          <!-- Tooltip on hover -->\n          <div\n            class=\"absolute top-full left-0 mt-2 w-80 p-3 rounded border border-yellow-500/30 bg-exo-dark-gray/95 backdrop-blur-sm opacity-0 invisible group-hover:opacity-100 group-hover:visible transition-all duration-200 z-50 shadow-lg\"\n          >\n            <p class=\"text-xs text-white/80 mb-2\">\n              A network routing cycle was detected between nodes connected via\n              Thunderbolt Bridge. 
This can cause connectivity issues.\n            </p>\n            <p class=\"text-xs text-white/60 mb-2\">\n              <span class=\"text-yellow-300\">Affected nodes:</span>\n              {cycle.map(getNodeName).join(\" → \")}\n            </p>\n            <p class=\"text-xs text-white/60 mb-1\">\n              <span class=\"text-yellow-300\">To fix:</span> Disable the Thunderbolt\n              Bridge on one of the affected nodes:\n            </p>\n            <button\n              type=\"button\"\n              onclick={() => copyToClipboard(disableCmd)}\n              class=\"w-full flex items-center gap-2 text-[10px] font-mono bg-exo-black/60 px-2 py-1.5 rounded text-exo-yellow break-all text-left hover:bg-exo-black/80 transition-colors cursor-pointer group/copy\"\n              title=\"Click to copy\"\n            >\n              <span class=\"flex-1\">{disableCmd}</span>\n              <svg\n                class=\"w-3.5 h-3.5 flex-shrink-0 text-white/40 group-hover/copy:text-exo-yellow transition-colors\"\n                fill=\"none\"\n                viewBox=\"0 0 24 24\"\n                stroke=\"currentColor\"\n                stroke-width=\"2\"\n              >\n                {#if copiedCommand}\n                  <path\n                    stroke-linecap=\"round\"\n                    stroke-linejoin=\"round\"\n                    d=\"M5 13l4 4L19 7\"\n                  />\n                {:else}\n                  <path\n                    stroke-linecap=\"round\"\n                    stroke-linejoin=\"round\"\n                    d=\"M8 16H6a2 2 0 01-2-2V6a2 2 0 012-2h8a2 2 0 012 2v2m-6 12h8a2 2 0 002-2v-8a2 2 0 00-2-2h-8a2 2 0 00-2 2v8a2 2 0 002 2z\"\n                  />\n                {/if}\n              </svg>\n            </button>\n          </div>\n        </div>\n      {/if}\n\n      {#if macosVersionMismatch}\n        <div class=\"group relative\" role=\"alert\">\n          <div\n            class=\"flex items-center gap-2 px-3 py-2 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm cursor-help\"\n          >\n            <svg\n              class=\"w-5 h-5 text-yellow-400 flex-shrink-0\"\n              fill=\"none\"\n              viewBox=\"0 0 24 24\"\n              stroke=\"currentColor\"\n              stroke-width=\"2\"\n            >\n              <path\n                stroke-linecap=\"round\"\n                stroke-linejoin=\"round\"\n                d={warningIconPath}\n              />\n            </svg>\n            <span class=\"text-sm font-mono text-yellow-200\">\n              INCOMPATIBLE macOS VERSIONS\n            </span>\n          </div>\n\n          <!-- Tooltip on hover -->\n          <div\n            class=\"absolute top-full left-0 mt-2 w-80 p-3 rounded border border-yellow-500/30 bg-exo-dark-gray/95 backdrop-blur-sm opacity-0 invisible group-hover:opacity-100 group-hover:visible transition-all duration-200 z-50 shadow-lg\"\n          >\n            <p class=\"text-xs text-white/80 mb-2\">\n              Nodes in this cluster are running different macOS versions. 
This\n              may cause inference compatibility issues.\n            </p>\n            <div class=\"text-xs text-white/60 mb-2\">\n              <span class=\"text-yellow-300\">Node versions:</span>\n              {#each macosVersionMismatch as node}\n                <div class=\"ml-2\">\n                  {node.friendlyName} — macOS {node.version} ({node.buildVersion})\n                </div>\n              {/each}\n            </div>\n            <p class=\"text-xs text-white/60\">\n              <span class=\"text-yellow-300\">Suggested action:</span> Update all nodes\n              to the same macOS version for best compatibility.\n            </p>\n          </div>\n        </div>\n      {/if}\n\n      {#if tb5WithoutRdma && !tb5InfoDismissed}\n        <div class=\"group relative\" role=\"status\">\n          <div\n            class=\"flex items-center gap-2 px-3 py-2 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm cursor-help\"\n          >\n            <svg\n              class=\"w-5 h-5 text-yellow-400 flex-shrink-0\"\n              fill=\"none\"\n              viewBox=\"0 0 24 24\"\n              stroke=\"currentColor\"\n              stroke-width=\"2\"\n            >\n              <path\n                stroke-linecap=\"round\"\n                stroke-linejoin=\"round\"\n                d={warningIconPath}\n              />\n            </svg>\n            <span class=\"text-sm font-mono text-yellow-200\">\n              RDMA NOT ENABLED\n            </span>\n            <button\n              type=\"button\"\n              onclick={() => (tb5InfoDismissed = true)}\n              class=\"ml-1 text-yellow-300/60 hover:text-yellow-200 transition-colors cursor-pointer\"\n              title=\"Dismiss\"\n            >\n              <svg\n                class=\"w-4 h-4\"\n                fill=\"none\"\n                viewBox=\"0 0 24 24\"\n                stroke=\"currentColor\"\n                stroke-width=\"2\"\n              >\n                <path\n                  stroke-linecap=\"round\"\n                  stroke-linejoin=\"round\"\n                  d=\"M6 18L18 6M6 6l12 12\"\n                />\n              </svg>\n            </button>\n          </div>\n          <!-- Tooltip on hover -->\n          <div\n            class=\"absolute top-full left-0 mt-2 w-80 p-3 rounded border border-yellow-500/30 bg-exo-dark-gray/95 backdrop-blur-sm opacity-0 invisible group-hover:opacity-100 group-hover:visible transition-all duration-200 z-50 shadow-lg\"\n          >\n            <p class=\"text-xs text-white/80 mb-2\">\n              Thunderbolt 5 hardware detected on multiple nodes. 
Enable RDMA for\n              significantly faster inter-node communication.\n            </p>\n            <p class=\"text-xs text-white/60 mb-1.5\">\n              <span class=\"text-yellow-300\">To enable:</span>\n            </p>\n            <ol\n              class=\"text-xs text-white/60 list-decimal list-inside space-y-0.5 mb-1.5\"\n            >\n              <li>Connect nodes with TB5 cables</li>\n              <li>Boot to Recovery (hold power 10s → Options)</li>\n              <li>\n                Run\n                <code class=\"text-yellow-300 bg-yellow-400/10 px-1 rounded\"\n                  >rdma_ctl enable</code\n                >\n              </li>\n              <li>Reboot</li>\n            </ol>\n            <p class=\"text-xs text-white/40\">\n              Requires macOS 26.2+, TB5 cables, and matching OS versions.\n            </p>\n          </div>\n        </div>\n      {/if}\n\n      {#if macStudioEn2RdmaWarning && !macStudioEn2Dismissed}\n        <div class=\"group relative\" role=\"alert\">\n          <div\n            class=\"flex items-center gap-2 px-3 py-2 rounded border border-red-500/50 bg-red-500/10 backdrop-blur-sm cursor-help\"\n          >\n            <svg\n              class=\"w-5 h-5 text-red-400 flex-shrink-0\"\n              fill=\"none\"\n              viewBox=\"0 0 24 24\"\n              stroke=\"currentColor\"\n              stroke-width=\"2\"\n            >\n              <path\n                stroke-linecap=\"round\"\n                stroke-linejoin=\"round\"\n                d={warningIconPath}\n              />\n            </svg>\n            <span class=\"text-sm font-mono text-red-200\">\n              RDMA INCOMPATIBLE PORT\n            </span>\n            <button\n              type=\"button\"\n              onclick={() => (macStudioEn2Dismissed = true)}\n              class=\"ml-1 text-red-300/60 hover:text-red-200 transition-colors cursor-pointer\"\n              title=\"Dismiss\"\n            >\n              <svg\n                class=\"w-4 h-4\"\n                fill=\"none\"\n                viewBox=\"0 0 24 24\"\n                stroke=\"currentColor\"\n                stroke-width=\"2\"\n              >\n                <path\n                  stroke-linecap=\"round\"\n                  stroke-linejoin=\"round\"\n                  d=\"M6 18L18 6M6 6l12 12\"\n                />\n              </svg>\n            </button>\n          </div>\n\n          <!-- Expanded tooltip on hover -->\n          <div\n            class=\"absolute top-full left-0 mt-2 w-96 p-4 rounded border border-red-500/30 bg-[#1a1a1a]/95 backdrop-blur-sm opacity-0 invisible group-hover:opacity-100 group-hover:visible transition-all duration-200 z-50 shadow-lg\"\n          >\n            <p class=\"text-xs text-white/80 mb-3\">\n              The Thunderbolt 5 port next to the Ethernet port on Mac Studio\n              does\n              <span class=\"text-red-400 font-semibold\">not support RDMA</span>.\n              Move the cable to one of the other three TB5 ports.\n            </p>\n\n            <div class=\"text-xs text-white/60 mb-3\">\n              <span class=\"text-red-300\">Affected:</span>\n              {#each macStudioEn2RdmaWarning as conn}\n                <div class=\"ml-2 mt-0.5\">\n                  <span class=\"text-white/80\">{conn.nodeName}</span>\n                  <span class=\"text-white/30\">&rarr;</span>\n                  <span class=\"text-white/60\">{conn.peerNodeName}</span>\n                  <span 
class=\"text-white/30 ml-1\">(en2)</span>\n                </div>\n              {/each}\n            </div>\n\n            <!-- Mac Studio back panel illustration -->\n            <div class=\"bg-black/40 rounded p-3 mb-3\">\n              <p\n                class=\"text-[10px] font-mono text-white/30 uppercase tracking-wider mb-2\"\n              >\n                Mac Studio — Rear Panel\n              </p>\n              <svg\n                viewBox=\"0 0 320 72\"\n                class=\"w-full\"\n                xmlns=\"http://www.w3.org/2000/svg\"\n              >\n                <rect\n                  x=\"1\"\n                  y=\"1\"\n                  width=\"318\"\n                  height=\"70\"\n                  rx=\"6\"\n                  ry=\"6\"\n                  fill=\"none\"\n                  stroke=\"rgba(255,255,255,0.12)\"\n                  stroke-width=\"1\"\n                />\n                <!-- TB5 port 1 -->\n                <rect\n                  x=\"24\"\n                  y=\"22\"\n                  width=\"28\"\n                  height=\"14\"\n                  rx=\"4\"\n                  fill=\"none\"\n                  stroke=\"rgba(255,255,255,0.3)\"\n                  stroke-width=\"1\"\n                />\n                <text\n                  x=\"38\"\n                  y=\"52\"\n                  text-anchor=\"middle\"\n                  fill=\"rgba(255,255,255,0.25)\"\n                  style=\"font-size:7px;font-family:ui-monospace,monospace;\"\n                  >TB5</text\n                >\n                <!-- TB5 port 2 -->\n                <rect\n                  x=\"62\"\n                  y=\"22\"\n                  width=\"28\"\n                  height=\"14\"\n                  rx=\"4\"\n                  fill=\"none\"\n                  stroke=\"rgba(255,255,255,0.3)\"\n                  stroke-width=\"1\"\n                />\n                <text\n                  x=\"76\"\n                  y=\"52\"\n                  text-anchor=\"middle\"\n                  fill=\"rgba(255,255,255,0.25)\"\n                  style=\"font-size:7px;font-family:ui-monospace,monospace;\"\n                  >TB5</text\n                >\n                <!-- TB5 port 3 -->\n                <rect\n                  x=\"100\"\n                  y=\"22\"\n                  width=\"28\"\n                  height=\"14\"\n                  rx=\"4\"\n                  fill=\"none\"\n                  stroke=\"rgba(255,255,255,0.3)\"\n                  stroke-width=\"1\"\n                />\n                <text\n                  x=\"114\"\n                  y=\"52\"\n                  text-anchor=\"middle\"\n                  fill=\"rgba(255,255,255,0.25)\"\n                  style=\"font-size:7px;font-family:ui-monospace,monospace;\"\n                  >TB5</text\n                >\n                <!-- TB5 port 4: INCOMPATIBLE (en2) — equally spaced with ports 1-3 -->\n                <rect\n                  x=\"138\"\n                  y=\"22\"\n                  width=\"28\"\n                  height=\"14\"\n                  rx=\"4\"\n                  fill=\"rgba(239,68,68,0.1)\"\n                  stroke=\"rgba(239,68,68,0.7)\"\n                  stroke-width=\"1.5\"\n                />\n                <line\n                  x1=\"142\"\n                  y1=\"25\"\n                  x2=\"162\"\n                  y2=\"33\"\n                  stroke=\"rgba(239,68,68,0.8)\"\n                  stroke-width=\"1.5\"\n                  
stroke-linecap=\"round\"\n                />\n                <line\n                  x1=\"162\"\n                  y1=\"25\"\n                  x2=\"142\"\n                  y2=\"33\"\n                  stroke=\"rgba(239,68,68,0.8)\"\n                  stroke-width=\"1.5\"\n                  stroke-linecap=\"round\"\n                />\n                <text\n                  x=\"152\"\n                  y=\"52\"\n                  text-anchor=\"middle\"\n                  fill=\"rgba(239,68,68,0.6)\"\n                  style=\"font-size:7px;font-family:ui-monospace,monospace;font-weight:600;\"\n                  >en2</text\n                >\n                <!-- Ethernet port -->\n                <rect\n                  x=\"196\"\n                  y=\"19\"\n                  width=\"24\"\n                  height=\"20\"\n                  rx=\"2\"\n                  fill=\"none\"\n                  stroke=\"rgba(255,255,255,0.2)\"\n                  stroke-width=\"1\"\n                />\n                <rect\n                  x=\"200\"\n                  y=\"23\"\n                  width=\"16\"\n                  height=\"12\"\n                  rx=\"1\"\n                  fill=\"none\"\n                  stroke=\"rgba(255,255,255,0.12)\"\n                  stroke-width=\"0.75\"\n                />\n                <text\n                  x=\"208\"\n                  y=\"52\"\n                  text-anchor=\"middle\"\n                  fill=\"rgba(255,255,255,0.25)\"\n                  style=\"font-size:7px;font-family:ui-monospace,monospace;\"\n                  >ETH</text\n                >\n                <!-- Green checkmarks on working ports -->\n                <circle\n                  cx=\"38\"\n                  cy=\"62\"\n                  r=\"3\"\n                  fill=\"none\"\n                  stroke=\"rgba(74,222,128,0.5)\"\n                  stroke-width=\"0.75\"\n                />\n                <circle\n                  cx=\"76\"\n                  cy=\"62\"\n                  r=\"3\"\n                  fill=\"none\"\n                  stroke=\"rgba(74,222,128,0.5)\"\n                  stroke-width=\"0.75\"\n                />\n                <circle\n                  cx=\"114\"\n                  cy=\"62\"\n                  r=\"3\"\n                  fill=\"none\"\n                  stroke=\"rgba(74,222,128,0.5)\"\n                  stroke-width=\"0.75\"\n                />\n              </svg>\n            </div>\n\n            <p class=\"text-xs text-white/50\">\n              <span class=\"text-green-400\">Fix:</span> Move the Thunderbolt cable\n              to any of the three leftmost ports (all support RDMA).\n            </p>\n          </div>\n        </div>\n      {/if}\n    </div>\n  {/if}\n{/snippet}\n\n{#snippet clusterWarningsCompact()}\n  {#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed) || (macStudioEn2RdmaWarning && !macStudioEn2Dismissed)}\n    <div class=\"absolute top-2 left-2 flex flex-col gap-1\">\n      {#if tbBridgeCycles.length > 0}\n        <div\n          class=\"flex items-center gap-1.5 px-2 py-1 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm\"\n          title=\"Thunderbolt Bridge cycle detected\"\n        >\n          <svg\n            class=\"w-3.5 h-3.5 text-yellow-400\"\n            fill=\"none\"\n            viewBox=\"0 0 24 24\"\n            stroke=\"currentColor\"\n            stroke-width=\"2\"\n          >\n            <path\n              
stroke-linecap=\"round\"\n              stroke-linejoin=\"round\"\n              d={warningIconPath}\n            />\n          </svg>\n          <span class=\"text-[10px] font-mono text-yellow-200\">TB CYCLE</span>\n        </div>\n      {/if}\n      {#if macosVersionMismatch}\n        <div\n          class=\"flex items-center gap-1.5 px-2 py-1 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm\"\n          title=\"Incompatible macOS versions detected\"\n        >\n          <svg\n            class=\"w-3.5 h-3.5 text-yellow-400\"\n            fill=\"none\"\n            viewBox=\"0 0 24 24\"\n            stroke=\"currentColor\"\n            stroke-width=\"2\"\n          >\n            <path\n              stroke-linecap=\"round\"\n              stroke-linejoin=\"round\"\n              d={warningIconPath}\n            />\n          </svg>\n          <span class=\"text-[10px] font-mono text-yellow-200\"\n            >macOS MISMATCH</span\n          >\n        </div>\n      {/if}\n      {#if tb5WithoutRdma && !tb5InfoDismissed}\n        <div\n          class=\"flex items-center gap-1.5 px-2 py-1 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm\"\n          title=\"Thunderbolt 5 detected — RDMA not enabled. Enable for faster inter-node communication.\"\n        >\n          <svg\n            class=\"w-3.5 h-3.5 text-yellow-400\"\n            fill=\"none\"\n            viewBox=\"0 0 24 24\"\n            stroke=\"currentColor\"\n            stroke-width=\"2\"\n          >\n            <path\n              stroke-linecap=\"round\"\n              stroke-linejoin=\"round\"\n              d={warningIconPath}\n            />\n          </svg>\n          <span class=\"text-[10px] font-mono text-yellow-200\"\n            >RDMA NOT ENABLED</span\n          >\n        </div>\n      {/if}\n      {#if macStudioEn2RdmaWarning && !macStudioEn2Dismissed}\n        <div\n          class=\"flex items-center gap-1.5 px-2 py-1 rounded border border-red-500/50 bg-red-500/10 backdrop-blur-sm\"\n          title=\"Mac Studio RDMA incompatible port (en2) — move cable to another TB5 port\"\n        >\n          <svg\n            class=\"w-3.5 h-3.5 text-red-400\"\n            fill=\"none\"\n            viewBox=\"0 0 24 24\"\n            stroke=\"currentColor\"\n            stroke-width=\"2\"\n          >\n            <path\n              stroke-linecap=\"round\"\n              stroke-linejoin=\"round\"\n              d={warningIconPath}\n            />\n          </svg>\n          <span class=\"text-[10px] font-mono text-red-200\">BAD RDMA PORT</span>\n        </div>\n      {/if}\n    </div>\n  {/if}\n{/snippet}\n\n<!-- Global event listeners for slider dragging + onboarding keyboard nav -->\n<svelte:window\n  onmousemove={handleSliderMouseMove}\n  onmouseup={handleSliderMouseUp}\n  ontouchmove={handleSliderTouchMove}\n  ontouchend={handleSliderTouchEnd}\n  onkeydown={(e) => {\n    if (!showOnboardingOverlay || stepTransitioning) return;\n    if (e.key === \"ArrowRight\" || e.key === \" \" || e.key === \"Enter\") {\n      if (onboardingStep >= 1 && onboardingStep <= 4 && showContinueButton) {\n        e.preventDefault();\n        advanceStep(onboardingStep < 4 ? 
onboardingStep + 1 : 6);\n      }\n    }\n  }}\n/>\n\n<div\n  class=\"relative h-screen w-full flex flex-col bg-exo-dark-gray overflow-hidden\"\n>\n  <!-- Scanline overlay -->\n  <div\n    class=\"fixed inset-0 pointer-events-none z-50 scanlines\"\n    style=\"transition: opacity 0.5s ease; opacity: {showOnboardingOverlay\n      ? 0\n      : 0.2};\"\n  ></div>\n\n  <!-- Shooting Stars Background -->\n  <div\n    class=\"shooting-stars\"\n    style=\"transition: opacity 0.5s ease; opacity: {showOnboardingOverlay\n      ? 0.4\n      : 1};\"\n  >\n    <div\n      class=\"shooting-star\"\n      style=\"top: 10%; left: 20%; --duration: 45s; --delay: 0s;\"\n    ></div>\n    <div\n      class=\"shooting-star\"\n      style=\"top: 30%; left: 65%; --duration: 45s; --delay: 15s;\"\n    ></div>\n    <div\n      class=\"shooting-star\"\n      style=\"top: 50%; left: 40%; --duration: 45s; --delay: 30s;\"\n    ></div>\n  </div>\n\n  {#if showOnboardingOverlay}\n    <!-- ═══════════════════════════════════════════════════════ -->\n    <!-- FULL-SCREEN ONBOARDING WIZARD (overlay)                -->\n    <!-- ═══════════════════════════════════════════════════════ -->\n    <div\n      class=\"absolute inset-0 flex items-center justify-center z-30 bg-exo-black\"\n      style=\"transition: opacity 0.45s cubic-bezier(0.4, 0, 0.2, 1); opacity: {onboardingFadingOut\n        ? 0\n        : 1};\"\n    >\n      {#if onboardingStep >= 1 && onboardingStep <= 4}\n        <!-- Steps 1-4: Cinematic SVG animation story -->\n        <div\n          class=\"flex flex-col items-center w-full max-w-3xl px-8\"\n          style=\"transition: opacity 0.6s cubic-bezier(0.4, 0, 0.2, 1), transform 0.6s cubic-bezier(0.4, 0, 0.2, 1); opacity: {stepTransitioning\n            ? 0\n            : 1}; transform: scale({stepTransitioning ? 0.98 : 1});\"\n        >\n          <!-- Logo + Step title -->\n          <div class=\"text-center mb-8\">\n            <!-- Logo — smoothly shrinks away when leaving step 1 -->\n            <div\n              style=\"opacity: {$logoOpacity}; max-height: {$logoOpacity *\n                80}px; overflow: hidden; transition: max-height 0.6s cubic-bezier(0.4, 0, 0.2, 1);\"\n            >\n              <img src=\"/exo-logo.png\" alt=\"exo\" class=\"w-36 mx-auto mb-10\" />\n            </div>\n\n            <!-- Title — single element, text updates instantly -->\n            <h1\n              class=\"text-2xl font-light text-white/90 tracking-wide\"\n              style=\"opacity: {$titleOpacity}; font-family: -apple-system, 'SF Pro Display', system-ui, sans-serif; letter-spacing: 0.02em;\"\n            >\n              {onboardingStep === 1\n                ? \"EXO connects all your devices into an AI supercomputer.\"\n                : stepTitle}\n            </h1>\n\n            <!-- Subtitle — uses tweened opacity, reserves space to prevent layout shift -->\n            <p\n              class=\"text-sm mt-2 text-white/40 max-w-md mx-auto\"\n              style=\"opacity: {$subtitleOpacity}; font-family: -apple-system, 'SF Pro Display', system-ui, sans-serif; font-weight: 300; min-height: 1.5em;\"\n            >\n              {#if onboardingStep === 2}\n                &nbsp;\n              {:else if onboardingStep === 3}\n                The model is automatically distributed. Each device handles a\n                piece.\n              {:else if onboardingStep === 4}\n                {stepTitle === \"exo self-heals\"\n                  ? 
\"exo automatically redistributes the model so inference continues without interruption.\"\n                  : \"Devices can leave anytime. Laptops close, machines restart.\"}\n              {:else}\n                &nbsp;\n              {/if}\n            </p>\n          </div>\n\n          <!-- Device display area -->\n          <div class=\"relative w-full\" style=\"height: 420px;\">\n            <!-- Device count label — fades in on step 1, fades out on step 2 -->\n            <p\n              class=\"absolute left-0 right-0 text-center text-lg text-white/50 font-light tracking-wide z-10\"\n              style=\"top: 20px; opacity: {$deviceCountOpacity}; font-family: -apple-system, 'SF Pro Display', system-ui, sans-serif; pointer-events: none;\"\n            >\n              Your EXO Network\n            </p>\n\n            <!-- Step 1: Real topology graph -->\n            {#if onboardingStep <= 1 || $topologyOpacity > 0.01}\n              <div\n                class=\"absolute inset-0 flex items-center justify-center\"\n                style=\"opacity: {$topologyOpacity}; pointer-events: {onboardingStep <=\n                1\n                  ? 'none'\n                  : 'none'};\"\n              >\n                <TopologyGraph class=\"w-full h-full\" />\n              </div>\n            {/if}\n\n            <!-- Steps 2+: Tweened SVG canvas with device pair -->\n            <svg\n              viewBox=\"0 0 700 420\"\n              class=\"w-full h-full\"\n              xmlns=\"http://www.w3.org/2000/svg\"\n              style=\"position: relative;\"\n            >\n              <!-- Device 1 (User's device) -->\n              <g\n                transform=\"translate({$device1X}, 210)\"\n                opacity={$device1Opacity}\n                style=\"transition: opacity 0.6s ease;\"\n              >\n                <DeviceIcon\n                  deviceType={userDeviceInfo.deviceType}\n                  cx={0}\n                  cy={0}\n                  size={110}\n                  ramPercent={60}\n                  uid=\"onb-d1\"\n                />\n                <text\n                  x=\"0\"\n                  y=\"-105\"\n                  text-anchor=\"middle\"\n                  fill=\"rgba(255,255,255,0.9)\"\n                  style=\"font-size: 15px; font-family: -apple-system, 'SF Pro Display', system-ui, sans-serif; font-weight: 500; letter-spacing: 0.01em;\"\n                >\n                  {userDeviceInfo.name}\n                </text>\n                <text\n                  x=\"0\"\n                  y=\"105\"\n                  text-anchor=\"middle\"\n                  style=\"font-size: 14px; font-family: 'SF Mono', ui-monospace, monospace;\"\n                >\n                  <tspan fill=\"rgba(255,215,0,0.9)\"\n                    >{userDeviceInfo.memoryGB}</tspan\n                  ><tspan fill=\"rgba(255,255,255,0.4)\">{\" \"}GB</tspan>\n                </text>\n              </g>\n\n              <!-- Device 2 (Mac Studio — simulated) -->\n              <g\n                transform=\"translate({$device2X}, 210)\"\n                opacity={$device2Opacity}\n                style=\"transition: opacity 0.6s ease;\"\n              >\n                <!-- Dashed outline to indicate simulated device -->\n                <rect\n                  x={(-110 * 1.25) / 2 - 6}\n                  y={(-110 * 0.85) / 2 - 6}\n                  width={110 * 1.25 + 12}\n                  height={110 * 0.85 + 12}\n                  rx=\"6\"\n              
    fill=\"none\"\n                  stroke=\"rgba(255,255,255,0.12)\"\n                  stroke-dasharray=\"4,4\"\n                />\n                <DeviceIcon\n                  deviceType=\"mac studio\"\n                  cx={0}\n                  cy={0}\n                  size={110}\n                  ramPercent={80}\n                  uid=\"onb-d2\"\n                />\n                <text\n                  x=\"0\"\n                  y=\"-105\"\n                  text-anchor=\"middle\"\n                  fill=\"rgba(255,255,255,0.9)\"\n                  style=\"font-size: 15px; font-family: -apple-system, 'SF Pro Display', system-ui, sans-serif; font-weight: 500; letter-spacing: 0.01em;\"\n                >\n                  Mac Studio\n                </text>\n                <text\n                  x=\"0\"\n                  y=\"105\"\n                  text-anchor=\"middle\"\n                  style=\"font-size: 14px; font-family: 'SF Mono', ui-monospace, monospace;\"\n                >\n                  <tspan fill=\"rgba(255,215,0,0.9)\">{SIMULATED_STUDIO_GB}</tspan\n                  ><tspan fill=\"rgba(255,255,255,0.4)\">{\" \"}GB</tspan>\n                </text>\n                <text\n                  x=\"0\"\n                  y=\"120\"\n                  text-anchor=\"middle\"\n                  fill=\"rgba(255,255,255,0.2)\"\n                  style=\"font-size: 9px; font-family: -apple-system, 'SF Pro Display', system-ui, sans-serif; font-style: italic;\"\n                >\n                  (example)\n                </text>\n              </g>\n\n              <!-- Connection line between devices -->\n              <line\n                x1={$device1X + 85}\n                y1={210}\n                x2={$device2X - 85}\n                y2={210}\n                stroke={$connectionIsRed > 0.5\n                  ? \"rgba(220,38,38,0.7)\"\n                  : \"rgba(255,255,255,0.15)\"}\n                stroke-width=\"1.5\"\n                stroke-dasharray=\"6,6\"\n                opacity={$connectionOpacity}\n                class={$connectionIsRed > 0.5\n                  ? 
\"onboarding-connection-line-red\"\n                  : \"onboarding-connection-line\"}\n              />\n\n              <!-- Disconnect X mark -->\n              {#if $disconnectXOpacity > 0.01}\n                <g\n                  transform=\"translate({($device1X + $device2X) / 2}, 210)\"\n                  opacity={$disconnectXOpacity}\n                >\n                  <circle\n                    r=\"18\"\n                    fill=\"rgba(220,38,38,0.1)\"\n                    stroke=\"rgba(220,38,38,0.6)\"\n                    stroke-width=\"1.5\"\n                  />\n                  <line\n                    x1=\"-8\"\n                    y1=\"-8\"\n                    x2=\"8\"\n                    y2=\"8\"\n                    stroke=\"rgba(220,38,38,0.8)\"\n                    stroke-width=\"2.5\"\n                    stroke-linecap=\"round\"\n                  />\n                  <line\n                    x1=\"8\"\n                    y1=\"-8\"\n                    x2=\"-8\"\n                    y2=\"8\"\n                    stroke=\"rgba(220,38,38,0.8)\"\n                    stroke-width=\"2.5\"\n                    stroke-linecap=\"round\"\n                  />\n                </g>\n              {/if}\n\n              <!-- Combined memory label -->\n              <text\n                x={($device1X + $device2X) / 2}\n                y={130}\n                text-anchor=\"middle\"\n                fill=\"rgba(255,215,0,0.7)\"\n                style=\"font-size: 14px; font-family: 'SF Mono', ui-monospace, monospace; font-weight: 500; letter-spacing: 0.02em;\"\n                opacity={$combinedLabelOpacity}\n              >\n                {onboardingCombinedGB} GB combined\n              </text>\n\n              <!-- Step 2: Models unlocked — staggered slide-up + yellow glow -->\n              {#if unlockedModels.length > 0 && $chipPhase > 0.01}\n                {@const centerX = ($device1X + $device2X) / 2}\n                {@const chipW = 140}\n                {@const chipH = 30}\n                {@const chipGap = 12}\n                {@const totalW =\n                  unlockedModels.length * chipW +\n                  (unlockedModels.length - 1) * chipGap}\n                {@const startX = centerX - totalW / 2}\n                <!-- SVG filter for yellow glow -->\n                <defs>\n                  <filter\n                    id=\"chip-glow\"\n                    x=\"-50%\"\n                    y=\"-50%\"\n                    width=\"200%\"\n                    height=\"200%\"\n                  >\n                    <feGaussianBlur\n                      in=\"SourceGraphic\"\n                      stdDeviation=\"4\"\n                      result=\"blur\"\n                    />\n                    <feColorMatrix\n                      in=\"blur\"\n                      type=\"matrix\"\n                      values=\"1 0.8 0 0 0  0.8 0.7 0 0 0  0 0 0 0 0  0 0 0 0.4 0\"\n                      result=\"glow\"\n                    />\n                    <feMerge>\n                      <feMergeNode in=\"glow\" />\n                      <feMergeNode in=\"SourceGraphic\" />\n                    </feMerge>\n                  </filter>\n                </defs>\n                <!-- Header slides up + fades with yellow tint -->\n                {@const headerProgress = Math.min(1, $chipPhase)}\n                {@const headerY = 332 + 12 * (1 - headerProgress)}\n                {@const yellowR = 234}\n                {@const yellowG = 179}\n               
 {@const yellowB = 8}\n                <text\n                  x={centerX}\n                  y={headerY}\n                  text-anchor=\"middle\"\n                  dominant-baseline=\"middle\"\n                  fill=\"rgba({yellowR},{yellowG},{yellowB},{0.5 *\n                    headerProgress})\"\n                  opacity={headerProgress}\n                  style=\"font-size: 10px; font-family: -apple-system, 'SF Pro Display', system-ui, sans-serif; font-weight: 500; letter-spacing: 0.1em;\"\n                >\n                  NEW MODELS UNLOCKED\n                </text>\n                <!-- Model chips — staggered slide-up + scale + yellow highlight -->\n                {#each unlockedModels as model, i}\n                  {@const stagger = i * 0.6}\n                  {@const progress = Math.max(\n                    0,\n                    Math.min(1, $chipPhase - stagger),\n                  )}\n                  {@const modelName = (\n                    model.name ||\n                    model.id.split(\"/\").pop() ||\n                    \"\"\n                  ).slice(0, 18)}\n                  {@const modelSize = Math.round(getModelSizeGB(model))}\n                  {@const slideY = 16 * (1 - progress)}\n                  {@const chipScale = 0.85 + 0.15 * progress}\n                  <!-- Yellow highlight peaks at ~0.6 progress then settles to subtle -->\n                  {@const highlightPeak =\n                    progress < 0.6\n                      ? progress / 0.6\n                      : 1 - ((progress - 0.6) / 0.4) * 0.6}\n                  {@const borderYellow = 0.15 + 0.35 * highlightPeak}\n                  {@const fillYellow = 0.02 + 0.06 * highlightPeak}\n                  {#if progress > 0}\n                    <g\n                      transform=\"translate({startX +\n                        i * (chipW + chipGap) +\n                        chipW / 2}, {358 + slideY}) scale({chipScale})\"\n                      opacity={progress}\n                      filter={highlightPeak > 0.3 ? \"url(#chip-glow)\" : \"none\"}\n                    >\n                      <rect\n                        x={-chipW / 2}\n                        y={-chipH / 2}\n                        width={chipW}\n                        height={chipH}\n                        rx=\"15\"\n                        fill=\"rgba({yellowR},{yellowG},{yellowB},{fillYellow})\"\n                        stroke=\"rgba({yellowR},{yellowG},{yellowB},{borderYellow})\"\n                        stroke-width=\"1\"\n                      />\n                      <text\n                        x=\"0\"\n                        y={modelSize ? 
-4 : 1}\n                        text-anchor=\"middle\"\n                        dominant-baseline=\"middle\"\n                        fill=\"rgba(255,255,255,{0.5 + 0.3 * progress})\"\n                        style=\"font-size: 10px; font-family: 'SF Mono', ui-monospace, monospace; font-weight: 500;\"\n                      >\n                        {modelName}\n                      </text>\n                      {#if modelSize}\n                        <text\n                          x=\"0\"\n                          y=\"8\"\n                          text-anchor=\"middle\"\n                          dominant-baseline=\"middle\"\n                          fill=\"rgba(255,255,255,{0.15 + 0.15 * progress})\"\n                          style=\"font-size: 8px; font-family: 'SF Mono', ui-monospace, monospace; font-weight: 400;\"\n                        >\n                          {modelSize} GB\n                        </text>\n                      {/if}\n                    </g>\n                  {/if}\n                {/each}\n              {/if}\n\n              <!-- Model block (unified or split) -->\n              {#if $modelBlockOpacity > 0.01}\n                {#if $modelSplitProgress < 0.05}\n                  <!-- Unified model block — centers on device1 when device2 is hidden -->\n                  {@const modelCenterX =\n                    $device2Opacity > 0.3\n                      ? ($device1X + $device2X) / 2\n                      : $device1X}\n                  <g\n                    transform=\"translate({modelCenterX}, {$modelBlockY})\"\n                    opacity={$modelBlockOpacity}\n                  >\n                    <rect\n                      x=\"-45\"\n                      y=\"-13\"\n                      width=\"90\"\n                      height=\"26\"\n                      rx=\"6\"\n                      fill=\"rgba(180,140,0,0.08)\"\n                      stroke=\"rgba(180,140,0,0.45)\"\n                      stroke-width=\"1.5\"\n                    />\n                    <text\n                      x=\"0\"\n                      y=\"5\"\n                      text-anchor=\"middle\"\n                      fill=\"rgba(220,180,40,0.9)\"\n                      style=\"font-size: 12px; font-family: -apple-system, system-ui, sans-serif; font-weight: 500;\"\n                    >\n                      LLM\n                    </text>\n                  </g>\n                {:else}\n                  <!-- Split model halves flowing down to each device -->\n                  {@const splitX =\n                    $modelSplitProgress * (($device2X - $device1X) / 2)}\n                  {@const centerX = ($device1X + $device2X) / 2}\n                  {@const splitY = $modelBlockY + $modelSplitProgress * 80}\n\n                  <!-- Left half -> Device 1 -->\n                  <g\n                    transform=\"translate({centerX - splitX}, {splitY})\"\n                    opacity={$modelBlockOpacity}\n                  >\n                    <rect\n                      x=\"-45\"\n                      y=\"-13\"\n                      width=\"90\"\n                      height=\"26\"\n                      rx=\"6\"\n                      fill=\"rgba(180,140,0,0.08)\"\n                      stroke=\"rgba(180,140,0,0.35)\"\n                      stroke-width=\"1\"\n                    />\n                    <text\n                      x=\"0\"\n                      y=\"4\"\n                      text-anchor=\"middle\"\n                      
fill=\"rgba(220,180,40,0.75)\"\n                      style=\"font-size: 11px; font-family: -apple-system, system-ui, sans-serif;\"\n                    >\n                      Shard 1/2\n                    </text>\n                  </g>\n\n                  <!-- Right half -> Device 2 -->\n                  <g\n                    transform=\"translate({centerX + splitX}, {splitY})\"\n                    opacity={$modelBlockOpacity * $device2Opacity}\n                  >\n                    <rect\n                      x=\"-45\"\n                      y=\"-13\"\n                      width=\"90\"\n                      height=\"26\"\n                      rx=\"6\"\n                      fill=\"rgba(180,140,0,0.08)\"\n                      stroke=\"rgba(180,140,0,0.35)\"\n                      stroke-width=\"1\"\n                    />\n                    <text\n                      x=\"0\"\n                      y=\"4\"\n                      text-anchor=\"middle\"\n                      fill=\"rgba(220,180,40,0.75)\"\n                      style=\"font-size: 11px; font-family: -apple-system, system-ui, sans-serif;\"\n                    >\n                      Shard 2/2\n                    </text>\n                  </g>\n                {/if}\n              {/if}\n            </svg>\n          </div>\n\n          <!-- Continue button — smooth transition, only for steps 1 and 5 -->\n          <div\n            style=\"transition: opacity 0.4s ease, transform 0.4s cubic-bezier(0.4,0,0.2,1); opacity: {showContinueButton\n              ? 1\n              : 0}; transform: translateY({showContinueButton\n              ? '0px'\n              : '12px'}); pointer-events: {showContinueButton\n              ? 'auto'\n              : 'none'}; margin-top: 0.5rem;\"\n          >\n            <button\n              type=\"button\"\n              onclick={() =>\n                advanceStep(onboardingStep < 4 ? 
onboardingStep + 1 : 6)}\n              class=\"inline-flex items-center gap-2.5 px-10 py-3.5 bg-exo-yellow text-exo-black text-sm font-semibold rounded-full cursor-pointer\"\n              style=\"transition: transform 0.2s ease, box-shadow 0.3s ease, filter 0.2s ease; font-family: -apple-system, 'SF Pro Display', system-ui, sans-serif; letter-spacing: 0.02em;\"\n              onmouseenter={(e) => {\n                e.currentTarget.style.filter = \"brightness(1.08)\";\n                e.currentTarget.style.boxShadow =\n                  \"0 0 30px rgba(255,215,0,0.2)\";\n              }}\n              onmouseleave={(e) => {\n                e.currentTarget.style.filter = \"brightness(1)\";\n                e.currentTarget.style.boxShadow = \"none\";\n              }}\n            >\n              Continue\n              <svg\n                class=\"w-4 h-4\"\n                fill=\"none\"\n                viewBox=\"0 0 24 24\"\n                stroke=\"currentColor\"\n                stroke-width=\"2.5\"\n              >\n                <path\n                  stroke-linecap=\"round\"\n                  stroke-linejoin=\"round\"\n                  d=\"M13 7l5 5m0 0l-5 5m5-5H6\"\n                />\n              </svg>\n            </button>\n          </div>\n        </div>\n      {:else if onboardingStep === 6}\n        <!-- Step 6: Choose a Model -->\n        <div\n          class=\"flex flex-col items-center w-full max-w-2xl px-8\"\n          style=\"opacity: 0; animation: onb-fade-in 0.5s ease forwards;\"\n        >\n          <div class=\"text-center mb-8\">\n            <h1\n              class=\"text-xl font-sans font-light text-white/90 mb-2 tracking-wide\"\n            >\n              Choose a model\n            </h1>\n            <p class=\"text-sm font-sans text-white/40\">\n              Showing recommended models for your devices ({Math.round(\n                clusterMemory().total / (1024 * 1024 * 1024),\n              )} GB memory available).\n            </p>\n          </div>\n\n          {#if onboardingError}\n            <div\n              class=\"w-full mb-6 px-4 py-3 rounded-lg border border-red-500/30 bg-red-500/10 text-sm font-mono text-red-400\"\n              in:fade={{ duration: 200 }}\n            >\n              {onboardingError}\n            </div>\n          {/if}\n\n          {#if onboardingModels.length === 0}\n            <div class=\"text-center py-8\">\n              <div class=\"text-sm text-white/40 font-sans animate-pulse\">\n                Loading models...\n              </div>\n            </div>\n          {:else}\n            <div class=\"w-full space-y-3 mb-8\">\n              {#each onboardingModels as model}\n                {@const sizeGB = getModelSizeGB(model)}\n                {@const fitsNow = hasEnoughMemory(model)}\n                {@const tags = modelTags()[model.id] || []}\n                <button\n                  type=\"button\"\n                  onclick={() => onboardingLaunchModel(model.id)}\n                  class=\"w-full flex items-center justify-between gap-4 px-5 py-4 rounded-xl border transition-all duration-200 cursor-pointer {fitsNow\n                    ? 
'border-white/10 bg-white/5 hover:border-exo-yellow/50 hover:bg-exo-yellow/5'\n                    : 'border-white/10 bg-white/[0.02] hover:border-white/20 opacity-60'}\"\n                >\n                  <div class=\"flex flex-col items-start gap-1 min-w-0\">\n                    <div class=\"flex items-center gap-2\">\n                      <span\n                        class=\"text-sm font-sans font-medium text-white truncate\"\n                        >{model.name || model.id}</span\n                      >\n                      {#each tags as tag}\n                        <span\n                          class=\"text-[10px] font-sans font-medium px-1.5 py-0.5 rounded-full bg-exo-yellow/10 text-exo-yellow/80\"\n                          >{tag}</span\n                        >\n                      {/each}\n                    </div>\n                    <span class=\"text-xs font-mono text-white/40 truncate\"\n                      >{model.id}</span\n                    >\n                  </div>\n                  <div class=\"flex items-center gap-3 flex-shrink-0\">\n                    <span class=\"text-xs font-mono text-white/50\"\n                      >{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)} GB</span\n                    >\n                    <svg\n                      class=\"w-4 h-4 text-white/40\"\n                      fill=\"none\"\n                      viewBox=\"0 0 24 24\"\n                      stroke=\"currentColor\"\n                      stroke-width=\"2\"\n                    >\n                      <path\n                        stroke-linecap=\"round\"\n                        stroke-linejoin=\"round\"\n                        d=\"M9 5l7 7-7 7\"\n                      />\n                    </svg>\n                  </div>\n                </button>\n              {/each}\n            </div>\n          {/if}\n\n          <button\n            type=\"button\"\n            onclick={() => {\n              modelPickerContext = \"dashboard\";\n              isModelPickerOpen = true;\n            }}\n            class=\"text-sm font-sans text-white/40 hover:text-exo-yellow transition-colors cursor-pointer underline underline-offset-4 decoration-white/20 hover:decoration-exo-yellow/50\"\n          >\n            Browse all models\n          </button>\n        </div>\n      {:else if onboardingStep === 7}\n        <!-- Step 7: Downloading -->\n        <div\n          class=\"text-center max-w-lg px-8\"\n          style=\"opacity: 0; animation: onb-fade-in 0.5s ease forwards;\"\n        >\n          <div class=\"mb-8\">\n            <h1\n              class=\"text-xl font-sans font-light text-white/90 mb-2 tracking-wide\"\n            >\n              Downloading\n            </h1>\n            {#if onboardingModelId}\n              <p class=\"text-sm text-white/40 font-sans\">\n                {onboardingModelId.split(\"/\").pop() ?? 
onboardingModelId}\n              </p>\n            {/if}\n          </div>\n\n          {#if onboardingDownloadProgress}\n            <div class=\"w-full max-w-md mx-auto space-y-4\">\n              <div\n                class=\"relative h-2 bg-white/10 rounded-full overflow-hidden\"\n              >\n                <div\n                  class=\"absolute inset-y-0 left-0 bg-gradient-to-r from-exo-yellow to-exo-yellow-darker rounded-full transition-all duration-500\"\n                  style=\"width: {onboardingDownloadProgress.percentage}%\"\n                ></div>\n              </div>\n              <div class=\"flex justify-between text-xs font-mono text-white/50\">\n                <span>{onboardingDownloadProgress.percentage.toFixed(1)}%</span>\n                <span\n                  >{formatBytes(onboardingDownloadProgress.downloadedBytes)} /\n                  {formatBytes(onboardingDownloadProgress.totalBytes)}</span\n                >\n              </div>\n              <div class=\"flex justify-between text-xs font-mono text-white/40\">\n                <span>{formatSpeed(onboardingDownloadProgress.speed)}</span>\n                <span>ETA: {formatEta(onboardingDownloadProgress.etaMs)}</span>\n              </div>\n            </div>\n          {:else}\n            <div class=\"w-full max-w-md mx-auto\">\n              <div\n                class=\"relative h-2 bg-white/10 rounded-full overflow-hidden\"\n              >\n                <div\n                  class=\"absolute inset-y-0 left-0 w-1/3 bg-gradient-to-r from-exo-yellow to-exo-yellow-darker rounded-full animate-pulse\"\n                ></div>\n              </div>\n              <p class=\"text-xs font-mono text-white/40 mt-4\">\n                Preparing download...\n              </p>\n            </div>\n          {/if}\n\n          <p class=\"text-xs font-sans text-white/40 mt-8\">\n            This may take a few minutes depending on your connection.\n          </p>\n        </div>\n      {:else if onboardingStep === 8}\n        <!-- Step 8: Loading into memory -->\n        <div\n          class=\"text-center max-w-lg px-8\"\n          style=\"opacity: 0; animation: onb-fade-in 0.5s ease forwards;\"\n        >\n          <div class=\"mb-6\">\n            <h1\n              class=\"text-xl font-sans font-light text-white/90 mb-2 tracking-wide\"\n            >\n              Loading into memory\n            </h1>\n            {#if onboardingModelId}\n              <p class=\"text-sm text-white/40 font-sans\">\n                {onboardingModelId.split(\"/\").pop() ?? 
onboardingModelId}\n              </p>\n            {/if}\n          </div>\n\n          <!-- Device icon -->\n          <div class=\"flex justify-center mb-6\">\n            <svg\n              viewBox=\"0 0 200 200\"\n              class=\"w-32 h-32\"\n              xmlns=\"http://www.w3.org/2000/svg\"\n            >\n              <DeviceIcon\n                deviceType={userDeviceInfo.deviceType}\n                cx={100}\n                cy={100}\n                size={80}\n                ramPercent={60}\n                uid=\"onb-loading\"\n              />\n            </svg>\n          </div>\n\n          {#if onboardingLoadProgress}\n            <div class=\"w-full max-w-xs mx-auto space-y-3\">\n              <div\n                class=\"relative h-2 bg-white/10 rounded-full overflow-hidden\"\n              >\n                <div\n                  class=\"absolute inset-y-0 left-0 bg-gradient-to-r from-exo-yellow to-exo-yellow-darker rounded-full transition-all duration-500\"\n                  style=\"width: {onboardingLoadProgress.percentage}%\"\n                ></div>\n              </div>\n              <p class=\"text-xs text-white/40 font-mono text-center\">\n                {onboardingLoadProgress.layersLoaded} / {onboardingLoadProgress.totalLayers}\n                layers loaded\n              </p>\n            </div>\n          {:else}\n            <div class=\"flex justify-center mb-4\">\n              <div\n                class=\"w-8 h-8 border-2 border-exo-yellow/15 border-t-exo-yellow/70 rounded-full animate-spin\"\n              ></div>\n            </div>\n            <p class=\"text-sm text-white/30 font-sans\">Loading...</p>\n          {/if}\n        </div>\n      {:else if onboardingStep === 9}\n        <!-- Step 9: Ready — centered input with suggestion chips -->\n        <!-- Uses onb-fade-opacity (no transform) so fixed-position dropdown in ChatForm works correctly -->\n        <div\n          class=\"flex flex-col items-center justify-center w-full max-w-2xl px-8\"\n          style=\"opacity: 0; animation: onb-fade-opacity 0.6s ease forwards;\"\n        >\n          <img\n            src=\"/exo-logo.png\"\n            alt=\"exo\"\n            class=\"w-28 mb-6\"\n            style=\"opacity: 0.8;\"\n          />\n\n          {#if onboardingModelId}\n            <p class=\"text-sm text-white/40 font-mono mb-6\">\n              {onboardingModelId.split(\"/\").pop() ?? 
onboardingModelId}\n            </p>\n          {/if}\n\n          <div class=\"w-full\">\n            <ChatForm\n              placeholder=\"Ask anything\"\n              autofocus={true}\n              showHelperText={false}\n              showModelSelector={true}\n              modelTasks={modelTasks()}\n              modelCapabilities={modelCapabilities()}\n              onOpenModelPicker={openChatModelPicker}\n              onAutoSend={handleChatSend}\n            />\n          </div>\n\n          <div class=\"flex flex-wrap justify-center gap-3 mt-6\">\n            {#each suggestedPrompts as chip}\n              <button\n                type=\"button\"\n                onclick={() => {\n                  completeOnboarding();\n                  sendMessage(chip);\n                }}\n                class=\"px-4 py-2 rounded-full border border-white/10 bg-white/5 text-sm text-white/60 hover:bg-white/10 hover:text-white/80 hover:border-white/20 transition-all duration-200 cursor-pointer\"\n              >\n                {chip}\n              </button>\n            {/each}\n          </div>\n        </div>\n      {/if}\n\n      <!-- Replay / Skip — visible on all onboarding steps -->\n      <div class=\"absolute bottom-8 flex items-center gap-6\">\n        <button\n          type=\"button\"\n          onclick={() => {\n            // Bounce through step 0 so the step-1 intro animation replays\n            onboardingStep = 0;\n            setTimeout(() => {\n              onboardingStep = 1;\n            }, 50);\n          }}\n          class=\"flex items-center gap-1.5 text-xs font-sans text-white/15 hover:text-white/35 transition-colors duration-300 cursor-pointer\"\n        >\n          <svg\n            width=\"12\"\n            height=\"12\"\n            viewBox=\"0 0 16 16\"\n            fill=\"none\"\n            stroke=\"currentColor\"\n            stroke-width=\"1.8\"\n            stroke-linecap=\"round\"\n            stroke-linejoin=\"round\"\n          >\n            <path d=\"M2.5 2v5h5\" />\n            <path d=\"M2.5 7a6.5 6.5 0 1 1 1.4-2.8\" />\n          </svg>\n          Replay\n        </button>\n        <button\n          type=\"button\"\n          onclick={completeOnboarding}\n          class=\"flex items-center gap-1.5 text-xs font-sans text-white/15 hover:text-white/35 transition-colors duration-300 cursor-pointer\"\n        >\n          <svg width=\"12\" height=\"12\" viewBox=\"0 0 16 16\" fill=\"currentColor\">\n            <path d=\"M3 2.5v11L9 8 3 2.5z\" />\n            <rect x=\"10.5\" y=\"2.5\" width=\"2.5\" height=\"11\" rx=\"0.5\" />\n          </svg>\n          Skip\n        </button>\n      </div>\n    </div>\n\n    <!-- Model Picker Modal (available during onboarding step 6) -->\n    {#if onboardingStep === 6}\n      <ModelPickerModal\n        isOpen={isModelPickerOpen}\n        {models}\n        {selectedModelId}\n        favorites={favoritesSet}\n        {recentModelIds}\n        hasRecents={showRecentsTab}\n        existingModelIds={new Set(models.map((m) => m.id))}\n        canModelFit={(modelId) => {\n          const model = models.find((m) => m.id === modelId);\n          return model ? hasEnoughMemory(model) : false;\n        }}\n        getModelFitStatus={(modelId): ModelMemoryFitStatus => {\n          const model = models.find((m) => m.id === modelId);\n          return model ? 
getModelMemoryFitStatus(model) : \"too_large\";\n        }}\n        onSelect={(modelId) => {\n          isModelPickerOpen = false;\n          onboardingLaunchModel(modelId);\n        }}\n        onClose={() => (isModelPickerOpen = false)}\n        onToggleFavorite={toggleFavorite}\n        onAddModel={addModelFromPicker}\n        onDeleteModel={deleteCustomModel}\n        totalMemoryGB={clusterMemory().total / (1024 * 1024 * 1024)}\n        usedMemoryGB={clusterMemory().used / (1024 * 1024 * 1024)}\n        {downloadsData}\n        topologyNodes={data?.nodes}\n      />\n    {/if}\n  {/if}\n\n  <!-- ═══════════════════════════════════════════════════════ -->\n  <!-- MAIN DASHBOARD (always rendered, behind onboarding)    -->\n  <!-- ═══════════════════════════════════════════════════════ -->\n  {#if !topologyOnlyEnabled}\n    <HeaderNav\n      showHome={true}\n      onHome={handleGoHome}\n      showSidebarToggle={true}\n      {sidebarVisible}\n      onToggleSidebar={toggleChatSidebarVisible}\n      showMobileMenuToggle={true}\n      mobileMenuOpen={mobileChatOpen}\n      onToggleMobileMenu={toggleMobileChatSidebar}\n      showMobileRightToggle={!chatStarted && !topologyOnlyEnabled}\n      {mobileRightOpen}\n      onToggleMobileRight={toggleMobileRightSidebar}\n      downloadProgress={activeDownloadSummary}\n    />\n  {/if}\n\n  <!-- Mobile Chat Sidebar Drawer -->\n  {#if !topologyOnlyEnabled}\n    <ChatSidebar\n      isMobileDrawer={true}\n      isOpen={mobileChatOpen}\n      onClose={() => setMobileChatSidebarOpen(false)}\n      onNewChat={handleNewChat}\n      onSelectConversation={() => {\n        userForcedIdle = false;\n      }}\n    />\n  {/if}\n\n  <!-- Main Content -->\n  <main class=\"flex-1 flex overflow-hidden relative\">\n    <!-- Left: Conversation History Sidebar (hidden in topology-only mode, welcome state, or when toggled off) - Desktop only -->\n    {#if !topologyOnlyEnabled && sidebarVisible}\n      <div\n        class=\"hidden md:block w-80 flex-shrink-0 border-r border-exo-yellow/10\"\n        role=\"complementary\"\n        aria-label=\"Conversation history\"\n      >\n        <ChatSidebar\n          class=\"h-full\"\n          onNewChat={handleNewChat}\n          onSelectConversation={() => {\n            userForcedIdle = false;\n          }}\n        />\n      </div>\n    {/if}\n\n    {#if topologyOnlyEnabled}\n      <!-- TOPOLOGY ONLY MODE: Full-screen topology -->\n      <div\n        class=\"flex-1 flex flex-col min-h-0 min-w-0 p-4\"\n        in:fade={{ duration: 300 }}\n      >\n        <div\n          class=\"flex-1 relative bg-exo-dark-gray/40 rounded-lg overflow-hidden\"\n        >\n          <TopologyGraph\n            class=\"w-full h-full\"\n            highlightedNodes={highlightedNodes()}\n            filteredNodes={nodeFilter}\n            onNodeClick={togglePreviewNodeFilter}\n          />\n\n          {@render clusterWarnings()}\n\n          <!-- TB5 RDMA Not Enabled Warning -->\n          {#if tb5WithoutRdma && !tb5InfoDismissed}\n            <div\n              class=\"absolute left-4 group\"\n              class:top-16={tbBridgeCycles.length > 0}\n              class:top-4={tbBridgeCycles.length === 0}\n              role=\"status\"\n            >\n              <div\n                class=\"flex items-center gap-2 px-3 py-2 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm cursor-help\"\n              >\n                <svg\n                  class=\"w-5 h-5 text-yellow-400 flex-shrink-0\"\n                  fill=\"none\"\n      
            viewBox=\"0 0 24 24\"\n                  stroke=\"currentColor\"\n                  stroke-width=\"2\"\n                >\n                  <path\n                    stroke-linecap=\"round\"\n                    stroke-linejoin=\"round\"\n                    d={warningIconPath}\n                  />\n                </svg>\n                <span class=\"text-sm font-mono text-yellow-200\">\n                  RDMA NOT ENABLED\n                </span>\n                <button\n                  type=\"button\"\n                  onclick={() => (tb5InfoDismissed = true)}\n                  class=\"ml-1 text-yellow-300/60 hover:text-yellow-200 transition-colors cursor-pointer\"\n                  title=\"Dismiss\"\n                >\n                  <svg\n                    class=\"w-4 h-4\"\n                    fill=\"none\"\n                    viewBox=\"0 0 24 24\"\n                    stroke=\"currentColor\"\n                    stroke-width=\"2\"\n                  >\n                    <path\n                      stroke-linecap=\"round\"\n                      stroke-linejoin=\"round\"\n                      d=\"M6 18L18 6M6 6l12 12\"\n                    />\n                  </svg>\n                </button>\n              </div>\n              <!-- Tooltip on hover -->\n              <div\n                class=\"absolute top-full left-0 mt-2 w-80 p-3 rounded border border-yellow-500/30 bg-exo-dark-gray/95 backdrop-blur-sm opacity-0 invisible group-hover:opacity-100 group-hover:visible transition-all duration-200 z-50 shadow-lg\"\n              >\n                <p class=\"text-xs text-white/80 mb-2\">\n                  Thunderbolt 5 hardware detected on multiple nodes. Enable RDMA\n                  for significantly faster inter-node communication.\n                </p>\n                <p class=\"text-xs text-white/60 mb-1.5\">\n                  <span class=\"text-yellow-300\">To enable:</span>\n                </p>\n                <ol\n                  class=\"text-xs text-white/60 list-decimal list-inside space-y-0.5 mb-1.5\"\n                >\n                  <li>Connect nodes with TB5 cables</li>\n                  <li>Boot to Recovery (hold power 10s → Options)</li>\n                  <li>\n                    Run\n                    <code class=\"text-yellow-300 bg-yellow-400/10 px-1 rounded\"\n                      >rdma_ctl enable</code\n                    >\n                  </li>\n                  <li>Reboot</li>\n                </ol>\n                <p class=\"text-xs text-white/40\">\n                  Requires macOS 26.2+, TB5 cables, and matching OS versions.\n                </p>\n              </div>\n            </div>\n          {/if}\n\n          <!-- Exit topology-only mode button -->\n          <button\n            type=\"button\"\n            onclick={toggleTopologyOnlyMode}\n            class=\"absolute bottom-4 right-4 p-2 rounded border border-exo-yellow/30 bg-exo-dark-gray/80 hover:border-exo-yellow/50 hover:bg-exo-dark-gray transition-colors cursor-pointer backdrop-blur-sm\"\n            title=\"Exit topology only mode\"\n            aria-label=\"Exit topology only mode\"\n          >\n            <svg\n              class=\"w-5 h-5 text-exo-yellow\"\n              fill=\"none\"\n              viewBox=\"0 0 24 24\"\n              stroke=\"currentColor\"\n              stroke-width=\"2\"\n            >\n              <circle cx=\"12\" cy=\"5\" r=\"2\" fill=\"currentColor\" />\n              <circle cx=\"5\" cy=\"19\" 
r=\"2\" fill=\"currentColor\" />\n              <circle cx=\"19\" cy=\"19\" r=\"2\" fill=\"currentColor\" />\n              <path stroke-linecap=\"round\" d=\"M12 7v5m0 0l-5 5m5-5l5 5\" />\n            </svg>\n          </button>\n        </div>\n      </div>\n    {:else if !chatStarted}\n      <!-- WELCOME STATE: Topology + Instance Controls (no left sidebar for cleaner look) -->\n      <div\n        class=\"flex-1 flex overflow-hidden relative\"\n        in:fade={{ duration: 300 }}\n        out:fade={{ duration: 200 }}\n      >\n        <!-- Center: MAIN TOPOLOGY DISPLAY -->\n        <div class=\"flex-1 flex flex-col min-h-0 min-w-0 py-4\">\n          <!-- Topology Container - Takes most of the space -->\n          <div\n            class=\"flex-1 relative bg-exo-dark-gray/40 mx-4 mb-4 rounded-lg overflow-hidden\"\n          >\n            <!-- The main topology graph - full container -->\n            <TopologyGraph\n              class=\"w-full h-full\"\n              highlightedNodes={highlightedNodes()}\n              filteredNodes={nodeFilter}\n              onNodeClick={togglePreviewNodeFilter}\n            />\n\n            <!-- Initial loading state before first data fetch -->\n            {#if !update}\n              <div\n                class=\"absolute inset-0 flex items-center justify-center bg-exo-dark-gray/80\"\n                in:fade={{ duration: 200 }}\n                out:fade={{ duration: 300 }}\n              >\n                <div class=\"text-center\">\n                  <div\n                    class=\"w-8 h-8 border-2 border-exo-yellow/30 border-t-exo-yellow rounded-full animate-spin mx-auto mb-4\"\n                  ></div>\n                  <p\n                    class=\"text-xs font-mono text-white/40 tracking-wider uppercase\"\n                  >\n                    Connecting to cluster&hellip;\n                  </p>\n                </div>\n              </div>\n            {/if}\n\n            {@render clusterWarnings()}\n\n            <!-- TB5 RDMA Not Enabled Warning -->\n            {#if tb5WithoutRdma && !tb5InfoDismissed}\n              <div\n                class=\"absolute left-4 group\"\n                class:top-16={tbBridgeCycles.length > 0}\n                class:top-4={tbBridgeCycles.length === 0}\n                role=\"status\"\n              >\n                <div\n                  class=\"flex items-center gap-2 px-3 py-2 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm cursor-help\"\n                >\n                  <svg\n                    class=\"w-5 h-5 text-yellow-400 flex-shrink-0\"\n                    fill=\"none\"\n                    viewBox=\"0 0 24 24\"\n                    stroke=\"currentColor\"\n                    stroke-width=\"2\"\n                  >\n                    <path\n                      stroke-linecap=\"round\"\n                      stroke-linejoin=\"round\"\n                      d={warningIconPath}\n                    />\n                  </svg>\n                  <span class=\"text-sm font-mono text-yellow-200\">\n                    RDMA NOT ENABLED\n                  </span>\n                  <button\n                    type=\"button\"\n                    onclick={() => (tb5InfoDismissed = true)}\n                    class=\"ml-1 text-yellow-300/60 hover:text-yellow-200 transition-colors cursor-pointer\"\n                    title=\"Dismiss\"\n                  >\n                    <svg\n                      class=\"w-4 h-4\"\n                      
fill=\"none\"\n                      viewBox=\"0 0 24 24\"\n                      stroke=\"currentColor\"\n                      stroke-width=\"2\"\n                    >\n                      <path\n                        stroke-linecap=\"round\"\n                        stroke-linejoin=\"round\"\n                        d=\"M6 18L18 6M6 6l12 12\"\n                      />\n                    </svg>\n                  </button>\n                </div>\n\n                <!-- Tooltip on hover -->\n                <div\n                  class=\"absolute top-full left-0 mt-2 w-80 p-3 rounded border border-yellow-500/30 bg-exo-dark-gray/95 backdrop-blur-sm opacity-0 invisible group-hover:opacity-100 group-hover:visible transition-all duration-200 z-50 shadow-lg\"\n                >\n                  <p class=\"text-xs text-white/80 mb-2\">\n                    Thunderbolt 5 hardware detected on multiple nodes. Enable\n                    RDMA for significantly faster inter-node communication.\n                  </p>\n                  <p class=\"text-xs text-white/60 mb-1.5\">\n                    <span class=\"text-yellow-300\">To enable:</span>\n                  </p>\n                  <ol\n                    class=\"text-xs text-white/60 list-decimal list-inside space-y-0.5 mb-1.5\"\n                  >\n                    <li>Connect nodes with TB5 cables</li>\n                    <li>Boot to Recovery (hold power 10s → Options)</li>\n                    <li>\n                      Run\n                      <code\n                        class=\"text-yellow-300 bg-yellow-400/10 px-1 rounded\"\n                        >rdma_ctl enable</code\n                      >\n                    </li>\n                    <li>Reboot</li>\n                  </ol>\n                  <p class=\"text-xs text-white/40\">\n                    Requires macOS 26.2+, TB5 cables, and matching OS versions.\n                  </p>\n                </div>\n              </div>\n            {/if}\n\n            <!-- Node Filter Indicator (top-right corner) -->\n            {#if isFilterActive()}\n              <button\n                onclick={clearPreviewNodeFilter}\n                class=\"absolute top-2 right-2 flex items-center gap-1.5 px-2 py-1 bg-exo-dark-gray/80 border border-exo-yellow/40 rounded text-exo-yellow hover:border-exo-yellow/60 transition-colors cursor-pointer backdrop-blur-sm\"\n                title=\"Clear filter\"\n              >\n                <span class=\"text-[10px] font-mono tracking-wider\">\n                  FILTER: {nodeFilter.size}\n                </span>\n                <svg\n                  class=\"w-3 h-3\"\n                  fill=\"none\"\n                  viewBox=\"0 0 24 24\"\n                  stroke=\"currentColor\"\n                  stroke-width=\"2\"\n                >\n                  <path\n                    stroke-linecap=\"round\"\n                    stroke-linejoin=\"round\"\n                    d=\"M6 18L18 6M6 6l12 12\"\n                  />\n                </svg>\n              </button>\n            {/if}\n          </div>\n\n          <!-- Chat Input - Below topology, never overlaps -->\n          <div class=\"px-4 pt-4 pb-6 flex-shrink-0\">\n            <div class=\"max-w-3xl mx-auto\">\n              {#if instanceCount === 0}\n                <div class=\"text-center mb-4\">\n                  <p class=\"text-sm text-white/50 font-sans\">\n                    Select a model to get started.\n                  </p>\n                
</div>\n              {/if}\n              <ChatForm\n                placeholder={instanceCount === 0\n                  ? \"Choose a model to start chatting\"\n                  : \"Ask anything\"}\n                showHelperText={false}\n                showModelSelector={true}\n                modelTasks={modelTasks()}\n                modelCapabilities={modelCapabilities()}\n                onOpenModelPicker={openChatModelPicker}\n                onAutoSend={handleChatSend}\n              />\n            </div>\n          </div>\n        </div>\n\n        <!-- Mobile Right Sidebar Drawer (Instances) -->\n        {#if mobileRightOpen}\n          <!-- Overlay backdrop -->\n          <button\n            type=\"button\"\n            class=\"fixed inset-0 bg-black/60 backdrop-blur-sm z-40 md:hidden\"\n            onclick={() => setMobileRightSidebarOpen(false)}\n            aria-label=\"Close instances panel\"\n          ></button>\n          <!-- Drawer panel -->\n          <aside\n            class=\"fixed right-0 top-0 bottom-0 w-80 bg-exo-dark-gray border-l border-exo-yellow/10 z-50 flex flex-col md:hidden overflow-y-auto\"\n            aria-label=\"Instance controls mobile\"\n          >\n            {@render rightSidebarContent()}\n          </aside>\n        {/if}\n\n        <!-- Right Sidebar: Instance Controls (wider on welcome page for better visibility) - Desktop only -->\n        <aside\n          class=\"hidden md:flex w-80 border-l border-exo-yellow/10 bg-exo-dark-gray flex-col flex-shrink-0\"\n          aria-label=\"Instance controls\"\n        >\n          {@render rightSidebarContent()}\n        </aside>\n\n        {#snippet rightSidebarContent()}\n          <!-- Running Instances Panel (only shown when instances exist) - Scrollable -->\n          {#if instanceCount > 0}\n            <div class=\"p-4 flex-shrink-0\">\n              <!-- Panel Header -->\n              <div class=\"flex items-center gap-2 mb-4\">\n                <div\n                  class=\"w-2 h-2 bg-exo-yellow rounded-full shadow-[0_0_8px_rgba(255,215,0,0.6)] animate-pulse\"\n                ></div>\n                <h3\n                  class=\"text-xs text-exo-yellow font-mono tracking-[0.2em] uppercase\"\n                >\n                  Instances\n                </h3>\n                <div\n                  class=\"flex-1 h-px bg-gradient-to-r from-exo-yellow/30 to-transparent\"\n                ></div>\n              </div>\n\n              <div\n                bind:this={instancesContainerRef}\n                class=\"max-h-72 xl:max-h-96 space-y-3 overflow-y-auto overflow-x-hidden py-px\"\n              >\n                {#each Object.entries(instanceData) as [id, instance]}\n                  {@const downloadInfo = getInstanceDownloadStatus(\n                    id,\n                    instance,\n                  )}\n                  {@const statusText = downloadInfo.statusText}\n                  {@const isDownloading = downloadInfo.isDownloading}\n                  {@const isFailed = statusText === \"FAILED\"}\n                  {@const isLoading = statusText === \"LOADING\"}\n                  {@const isWarmingUp =\n                    statusText === \"WARMING UP\" || statusText === \"WAITING\"}\n                  {@const isReady =\n                    statusText === \"READY\" || statusText === \"LOADED\"}\n                  {@const isRunning = statusText === \"RUNNING\"}\n                  <!-- Instance Card -->\n                  {@const instanceModelId = 
getInstanceModelId(instance)}\n                  {@const instanceInfo = getInstanceInfo(instance)}\n                  {@const instanceConnections =\n                    getInstanceConnections(instance)}\n                  <div\n                    class=\"relative group cursor-pointer\"\n                    role=\"button\"\n                    tabindex=\"0\"\n                    transition:slide={{ duration: 250, easing: cubicOut }}\n                    onmouseenter={() => (hoveredInstanceId = id)}\n                    onmouseleave={() => (hoveredInstanceId = null)}\n                    onclick={() => {\n                      if (\n                        instanceModelId &&\n                        instanceModelId !== \"Unknown\" &&\n                        instanceModelId !== \"Unknown Model\"\n                      ) {\n                        userForcedIdle = false;\n                        setSelectedChatModel(instanceModelId);\n                      }\n                    }}\n                    onkeydown={(e) => {\n                      if (e.key === \"Enter\" || e.key === \" \") {\n                        if (\n                          instanceModelId &&\n                          instanceModelId !== \"Unknown\" &&\n                          instanceModelId !== \"Unknown Model\"\n                        ) {\n                          setSelectedChatModel(instanceModelId);\n                        }\n                      }\n                    }}\n                  >\n                    <!-- Corner accents -->\n                    <div\n                      class=\"absolute -top-px -left-px w-2 h-2 border-l border-t {isDownloading\n                        ? 'border-blue-500/50'\n                        : isFailed\n                          ? 'border-red-500/50'\n                          : isLoading\n                            ? 'border-yellow-500/50'\n                            : isReady\n                              ? 'border-green-500/50'\n                              : 'border-teal-500/50'}\"\n                    ></div>\n                    <div\n                      class=\"absolute -top-px -right-px w-2 h-2 border-r border-t {isDownloading\n                        ? 'border-blue-500/50'\n                        : isFailed\n                          ? 'border-red-500/50'\n                          : isLoading\n                            ? 'border-yellow-500/50'\n                            : isReady\n                              ? 'border-green-500/50'\n                              : 'border-teal-500/50'}\"\n                    ></div>\n                    <div\n                      class=\"absolute -bottom-px -left-px w-2 h-2 border-l border-b {isDownloading\n                        ? 'border-blue-500/50'\n                        : isFailed\n                          ? 'border-red-500/50'\n                          : isLoading\n                            ? 'border-yellow-500/50'\n                            : isReady\n                              ? 'border-green-500/50'\n                              : 'border-teal-500/50'}\"\n                    ></div>\n                    <div\n                      class=\"absolute -bottom-px -right-px w-2 h-2 border-r border-b {isDownloading\n                        ? 'border-blue-500/50'\n                        : isFailed\n                          ? 'border-red-500/50'\n                          : isLoading\n                            ? 
'border-yellow-500/50'\n                            : isReady\n                              ? 'border-green-500/50'\n                              : 'border-teal-500/50'}\"\n                    ></div>\n\n                    <div\n                      class=\"bg-exo-dark-gray/60 border border-l-2 transition-all duration-200 group-hover:bg-exo-dark-gray/80 {isDownloading\n                        ? 'border-blue-500/30 border-l-blue-400 group-hover:border-blue-500/50'\n                        : isFailed\n                          ? 'border-red-500/30 border-l-red-400 group-hover:border-red-500/50'\n                          : isLoading\n                            ? 'border-exo-yellow/30 border-l-yellow-400 group-hover:border-exo-yellow/50'\n                            : isReady\n                              ? 'border-green-500/30 border-l-green-400 group-hover:border-green-500/50'\n                              : 'border-teal-500/30 border-l-teal-400 group-hover:border-teal-500/50'} p-3\"\n                    >\n                      <div class=\"flex justify-between items-start mb-2 pl-2\">\n                        <div class=\"flex items-center gap-2\">\n                          <div\n                            class=\"w-1.5 h-1.5 {isDownloading\n                              ? 'bg-blue-400 animate-pulse'\n                              : isFailed\n                                ? 'bg-red-400'\n                                : isLoading\n                                  ? 'bg-yellow-400 animate-pulse'\n                                  : isReady\n                                    ? 'bg-green-400'\n                                    : 'bg-teal-400'} rounded-full shadow-[0_0_6px_currentColor]\"\n                          ></div>\n                          <span\n                            class=\"text-exo-light-gray font-mono text-sm tracking-wider\"\n                            >{id.slice(0, 8).toUpperCase()}</span\n                          >\n                        </div>\n                        <button\n                          onclick={() => deleteInstance(id)}\n                          class=\"text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer\"\n                        >\n                          DELETE\n                        </button>\n                      </div>\n                      <div class=\"pl-2\">\n                        <div\n                          class=\"text-exo-yellow text-xs font-mono tracking-wide truncate\"\n                        >\n                          {getInstanceModelId(instance)}\n                        </div>\n                        <div\n                          class=\"flex items-center gap-2 text-white/60 text-xs font-mono\"\n                        >\n                          <span\n                            >{instanceInfo.sharding} &middot; {instanceInfo.instanceType}</span\n                          >\n                          <span\n                            class=\"px-1.5 py-0.5 text-[10px] tracking-wider uppercase rounded transition-all duration-300 {isDownloading\n                              ? 'bg-blue-500/15 text-blue-400'\n                              : isFailed\n                                ? 'bg-red-500/15 text-red-400'\n                                : isLoading\n                                  ? 
'bg-yellow-500/15 text-yellow-400'\n                                  : isReady\n                                    ? 'bg-green-500/15 text-green-400'\n                                    : 'bg-teal-500/15 text-teal-400'}\"\n                          >\n                            {statusText}\n                          </span>\n                        </div>\n                        {#if instanceModelId && instanceModelId !== \"Unknown\" && instanceModelId !== \"Unknown Model\"}\n                          <a\n                            class=\"inline-flex items-center gap-1 text-[11px] text-white/60 hover:text-exo-yellow transition-colors mt-1\"\n                            href={`https://huggingface.co/${instanceModelId}`}\n                            target=\"_blank\"\n                            rel=\"noreferrer noopener\"\n                            aria-label=\"View model on Hugging Face\"\n                          >\n                            <span>Hugging Face</span>\n                            <svg\n                              class=\"w-3.5 h-3.5\"\n                              viewBox=\"0 0 24 24\"\n                              fill=\"none\"\n                              stroke=\"currentColor\"\n                              stroke-width=\"2\"\n                              stroke-linecap=\"round\"\n                              stroke-linejoin=\"round\"\n                            >\n                              <path d=\"M14 3h7v7\" />\n                              <path d=\"M10 14l11-11\" />\n                              <path\n                                d=\"M21 14v6a1 1 0 0 1-1 1h-16a1 1 0 0 1-1-1v-16a1 1 0 0 1 1-1h6\"\n                              />\n                            </svg>\n                          </a>\n                        {/if}\n                        {#if instanceInfo.nodeNames.length > 0}\n                          <div class=\"text-white/60 text-xs font-mono\">\n                            {instanceInfo.nodeNames.join(\", \")}\n                          </div>\n                        {/if}\n                        {#if debugEnabled && instanceConnections.length > 0}\n                          <div class=\"mt-2 space-y-1\">\n                            {#each instanceConnections as conn}\n                              <div\n                                class=\"text-[11px] leading-snug font-mono text-white/70\"\n                              >\n                                <span>{conn.from} -> {conn.to}: {conn.ip}</span>\n                                <span\n                                  class={conn.missingIface\n                                    ? 
\"text-red-400\"\n                                    : \"text-white/60\"}\n                                >\n                                  ({conn.ifaceLabel})</span\n                                >\n                              </div>\n                            {/each}\n                          </div>\n                        {/if}\n\n                        <!-- Download Progress -->\n                        {#if downloadInfo.isDownloading && downloadInfo.progress}\n                          <div class=\"mt-2 space-y-1\">\n                            <div class=\"flex justify-between text-xs font-mono\">\n                              <span class=\"text-blue-400\"\n                                >{downloadInfo.progress.percentage.toFixed(\n                                  1,\n                                )}%</span\n                              >\n                              <span class=\"text-exo-light-gray\"\n                                >{formatBytes(\n                                  downloadInfo.progress.downloadedBytes,\n                                )}/{formatBytes(\n                                  downloadInfo.progress.totalBytes,\n                                )}</span\n                              >\n                            </div>\n                            <div\n                              class=\"relative h-1.5 bg-exo-black/60 rounded-sm overflow-hidden\"\n                            >\n                              <div\n                                class=\"absolute inset-y-0 left-0 bg-gradient-to-r from-blue-500 to-blue-400 transition-all duration-300\"\n                                style=\"width: {downloadInfo.progress\n                                  .percentage}%\"\n                              ></div>\n                            </div>\n                            <div\n                              class=\"flex justify-between text-xs font-mono text-exo-light-gray\"\n                            >\n                              <span\n                                >{formatSpeed(\n                                  downloadInfo.progress.speed,\n                                )}</span\n                              >\n                              <span\n                                >ETA: {formatEta(\n                                  downloadInfo.progress.etaMs,\n                                )}</span\n                              >\n                              <span\n                                >{downloadInfo.progress\n                                  .completedFiles}/{downloadInfo.progress\n                                  .totalFiles} files</span\n                              >\n                            </div>\n                          </div>\n                          {#if downloadInfo.perNode.length > 0}\n                            <div\n                              class=\"mt-2 space-y-2 max-h-48 overflow-y-auto pr-1\"\n                            >\n                              {#each downloadInfo.perNode.filter((n) => n.status === \"downloading\" && n.progress) as nodeProg}\n                                {@const nodePercent = Math.min(\n                                  100,\n                                  Math.max(0, nodeProg.percentage),\n                                )}\n                                {@const isExpanded =\n                                  instanceDownloadExpandedNodes.has(\n                                    nodeProg.nodeId,\n                                  )}\n 
                               <div\n                                  class=\"rounded border border-exo-medium-gray/40 bg-exo-black/30 p-2\"\n                                >\n                                  <button\n                                    type=\"button\"\n                                    class=\"w-full text-left space-y-1.5\"\n                                    onclick={() =>\n                                      toggleInstanceDownloadDetails(\n                                        nodeProg.nodeId,\n                                      )}\n                                  >\n                                    <div\n                                      class=\"flex items-center justify-between text-[11px] font-mono text-exo-light-gray\"\n                                    >\n                                      <span class=\"text-white/80 truncate pr-2\"\n                                        >{nodeProg.nodeName}</span\n                                      >\n                                      <span\n                                        class=\"flex items-center gap-1 text-blue-300\"\n                                      >\n                                        {nodePercent.toFixed(1)}%\n                                        <svg\n                                          class=\"w-3 h-3 text-exo-light-gray\"\n                                          viewBox=\"0 0 20 20\"\n                                          fill=\"none\"\n                                          stroke=\"currentColor\"\n                                          stroke-width=\"2\"\n                                        >\n                                          <path\n                                            d=\"M6 8l4 4 4-4\"\n                                            class={isExpanded\n                                              ? \"transform rotate-180 origin-center transition-transform duration-150\"\n                                              : \"transition-transform duration-150\"}\n                                          ></path>\n                                        </svg>\n                                      </span>\n                                    </div>\n                                    <div\n                                      class=\"relative h-1.5 bg-exo-black/60 rounded-sm overflow-hidden\"\n                                    >\n                                      <div\n                                        class=\"absolute inset-y-0 left-0 bg-gradient-to-r from-blue-500 to-blue-400 transition-all duration-300\"\n                                        style=\"width: {nodePercent.toFixed(1)}%\"\n                                      ></div>\n                                    </div>\n                                    <div\n                                      class=\"flex items-center justify-between text-[11px] font-mono text-exo-light-gray\"\n                                    >\n                                      <span\n                                        >{formatBytes(\n                                          nodeProg.progress?.downloadedBytes ??\n                                            0,\n                                        )} / {formatBytes(\n                                          nodeProg.progress?.totalBytes ?? 
0,\n                                        )}</span\n                                      >\n                                      <span\n                                        >{formatSpeed(\n                                          nodeProg.progress?.speed ?? 0,\n                                        )} • ETA {formatEta(\n                                          nodeProg.progress?.etaMs ?? 0,\n                                        )}</span\n                                      >\n                                    </div>\n                                  </button>\n\n                                  {#if isExpanded}\n                                    <div class=\"mt-2 space-y-1.5\">\n                                      {#if (nodeProg.progress?.files ?? []).length === 0}\n                                        <div\n                                          class=\"text-[11px] font-mono text-exo-light-gray/70\"\n                                        >\n                                          No file details reported.\n                                        </div>\n                                      {:else}\n                                        {#each nodeProg.progress?.files ?? [] as f}\n                                          {@const filePercent = Math.min(\n                                            100,\n                                            Math.max(0, f.percentage ?? 0),\n                                          )}\n                                          {@const isFileComplete =\n                                            filePercent >= 100}\n                                          <div\n                                            class=\"rounded border border-exo-medium-gray/30 bg-exo-black/40 p-2\"\n                                          >\n                                            <div\n                                              class=\"flex items-center justify-between text-[10px] font-mono text-exo-light-gray/90\"\n                                            >\n                                              <span class=\"truncate pr-2\"\n                                                >{f.name}</span\n                                              >\n                                              <span\n                                                class={isFileComplete\n                                                  ? \"text-green-400\"\n                                                  : \"text-white/80\"}\n                                                >{filePercent.toFixed(1)}%</span\n                                              >\n                                            </div>\n                                            <div\n                                              class=\"relative h-1 bg-exo-black/60 rounded-sm overflow-hidden mt-1\"\n                                            >\n                                              <div\n                                                class=\"absolute inset-y-0 left-0 bg-gradient-to-r {isFileComplete\n                                                  ? 
'from-green-500 to-green-400'\n                                                  : 'from-exo-yellow to-exo-yellow/70'} transition-all duration-300\"\n                                                style=\"width: {filePercent.toFixed(\n                                                  1,\n                                                )}%\"\n                                              ></div>\n                                            </div>\n                                            <div\n                                              class=\"flex items-center justify-between text-[10px] text-exo-light-gray/70 mt-0.5\"\n                                            >\n                                              <span\n                                                >{formatBytes(\n                                                  f.downloadedBytes,\n                                                )} / {formatBytes(\n                                                  f.totalBytes,\n                                                )}</span\n                                              >\n                                              <span\n                                                >{formatSpeed(f.speed)} • ETA {formatEta(\n                                                  f.etaMs,\n                                                )}</span\n                                              >\n                                            </div>\n                                          </div>\n                                        {/each}\n                                      {/if}\n                                    </div>\n                                  {/if}\n                                </div>\n                              {/each}\n                            </div>\n                          {/if}\n                          <div class=\"mt-2 space-y-1\">\n                            <div\n                              class=\"text-xs text-blue-400 font-mono tracking-wider\"\n                            >\n                              DOWNLOADING\n                            </div>\n                            <p\n                              class=\"text-[11px] text-white/50 leading-relaxed\"\n                            >\n                              Downloading model files. 
The model runs on your\n                              devices, so it needs to be downloaded before you\n                              can chat.\n                            </p>\n                          </div>\n                        {:else}\n                          <div class=\"mt-1 space-y-1\">\n                            <div\n                              class=\"text-xs {getStatusColor(\n                                downloadInfo.statusText,\n                              )} font-mono tracking-wider\"\n                            >\n                              {downloadInfo.statusText}\n                            </div>\n                            {#if isLoading}\n                              {@const loadStatus =\n                                deriveInstanceStatus(instance)}\n                              {#if loadStatus.totalLayers && loadStatus.totalLayers > 0}\n                                <div class=\"mt-1 space-y-1\">\n                                  <div\n                                    class=\"flex justify-between text-xs font-mono\"\n                                  >\n                                    <span class=\"text-yellow-400\"\n                                      >{(\n                                        ((loadStatus.layersLoaded ?? 0) /\n                                          loadStatus.totalLayers) *\n                                        100\n                                      ).toFixed(0)}%</span\n                                    >\n                                    <span class=\"text-exo-light-gray\"\n                                      >{loadStatus.layersLoaded ?? 0} / {loadStatus.totalLayers}\n                                      layers</span\n                                    >\n                                  </div>\n                                  <div\n                                    class=\"relative h-1.5 bg-exo-black/60 rounded-sm overflow-hidden\"\n                                  >\n                                    <div\n                                      class=\"absolute inset-y-0 left-0 bg-gradient-to-r from-yellow-500 to-yellow-400 transition-all duration-300\"\n                                      style=\"width: {((loadStatus.layersLoaded ??\n                                        0) /\n                                        loadStatus.totalLayers) *\n                                        100}%\"\n                                    ></div>\n                                  </div>\n                                </div>\n                              {:else}\n                                <p\n                                  class=\"text-[11px] text-white/50 leading-relaxed\"\n                                >\n                                  Loading model into memory...\n                                </p>\n                              {/if}\n                            {:else if isWarmingUp}\n                              <p\n                                class=\"text-[11px] text-white/50 leading-relaxed\"\n                              >\n                                Warming up...\n                              </p>\n                            {:else if isReady || isRunning}\n                              <p\n                                class=\"text-[11px] text-green-400/70 leading-relaxed\"\n                              >\n                                Ready to chat!\n                              </p>\n                            {/if}\n                          
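<!-- Failed downloads surface their error message just below this status block -->\n                          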
</div>\n                          {#if downloadInfo.isFailed && downloadInfo.errorMessage}\n                            <div\n                              class=\"text-xs text-red-400/80 font-mono mt-1 break-words\"\n                            >\n                              {downloadInfo.errorMessage}\n                            </div>\n                          {/if}\n                        {/if}\n                      </div>\n                    </div>\n                  </div>\n                {/each}\n              </div>\n            </div>\n          {/if}\n\n          <!-- Models Panel - Scrollable -->\n          <div class=\"p-4 flex-1 overflow-y-auto\">\n            <!-- Panel Header -->\n            <div class=\"flex items-center gap-2 mb-3 flex-shrink-0\">\n              <div class=\"w-2 h-2 border border-exo-yellow/60 rotate-45\"></div>\n              <h3\n                class=\"text-xs text-exo-yellow font-mono tracking-[0.2em] uppercase\"\n              >\n                Load Model\n              </h3>\n              <div\n                class=\"flex-1 h-px bg-gradient-to-r from-exo-yellow/30 to-transparent\"\n              ></div>\n              <span class=\"text-sm text-white/70 font-mono\"\n                >{models.length} models</span\n              >\n            </div>\n\n            <!-- Model Picker Button -->\n            <div class=\"flex-shrink-0 mb-3\">\n              <button\n                type=\"button\"\n                onclick={() => {\n                  modelPickerContext = \"dashboard\";\n                  isModelPickerOpen = true;\n                }}\n                class=\"w-full bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-3 pr-8 py-2.5 text-sm font-mono text-left tracking-wide cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 relative\"\n              >\n                {#if selectedModelId}\n                  {@const foundModel = models.find(\n                    (m) => m.id === selectedModelId,\n                  )}\n                  {#if foundModel}\n                    {@const sizeGB = getModelSizeGB(foundModel)}\n                    <span\n                      class=\"flex items-center justify-between gap-2 w-full pr-4\"\n                    >\n                      <span\n                        class=\"flex items-center gap-2 text-exo-light-gray truncate\"\n                      >\n                        <span class=\"truncate\"\n                          >{foundModel.name || foundModel.id}</span\n                        >\n                      </span>\n                      <span class=\"text-white/50 text-xs flex-shrink-0\"\n                        >{sizeGB >= 1\n                          ? 
sizeGB.toFixed(0)\n                          : sizeGB.toFixed(1)}GB</span\n                      >\n                    </span>\n                  {:else}\n                    <span class=\"text-exo-light-gray\">{selectedModelId}</span>\n                  {/if}\n                {:else if bestRunningModelId}\n                  {@const runModel = models.find(\n                    (m) => m.id === bestRunningModelId,\n                  )}\n                  {#if runModel}\n                    {@const sizeGB = getModelSizeGB(runModel)}\n                    <span\n                      class=\"flex items-center justify-between gap-2 w-full pr-4\"\n                    >\n                      <span\n                        class=\"flex items-center gap-2 text-exo-light-gray truncate\"\n                      >\n                        <span class=\"truncate\"\n                          >{runModel.name || runModel.id}</span\n                        >\n                      </span>\n                      <span class=\"text-white/50 text-xs flex-shrink-0\"\n                        >{sizeGB >= 1\n                          ? sizeGB.toFixed(0)\n                          : sizeGB.toFixed(1)}GB</span\n                      >\n                    </span>\n                  {:else}\n                    <span class=\"text-exo-light-gray\">{bestRunningModelId}</span\n                    >\n                  {/if}\n                {:else}\n                  <span class=\"text-white/50\">— SELECT MODEL —</span>\n                {/if}\n                <div\n                  class=\"absolute right-3 top-1/2 -translate-y-1/2 pointer-events-none\"\n                >\n                  <svg\n                    class=\"w-4 h-4 text-exo-yellow/60\"\n                    fill=\"none\"\n                    viewBox=\"0 0 24 24\"\n                    stroke=\"currentColor\"\n                  >\n                    <path\n                      stroke-linecap=\"round\"\n                      stroke-linejoin=\"round\"\n                      stroke-width=\"2\"\n                      d=\"M19 9l-7 7-7-7\"\n                    />\n                  </svg>\n                </div>\n              </button>\n            </div>\n\n            <!-- Advanced Options Toggle -->\n            <div class=\"flex-shrink-0 mb-4\">\n              <button\n                type=\"button\"\n                onclick={() => (showAdvancedOptions = !showAdvancedOptions)}\n                class=\"flex items-center gap-2 text-xs text-white/50 hover:text-white/70 font-mono tracking-wider uppercase transition-colors cursor-pointer py-1\"\n                aria-expanded={showAdvancedOptions}\n              >\n                <svg\n                  class=\"w-3 h-3 transition-transform duration-200 {showAdvancedOptions\n                    ? 
'rotate-90'\n                    : ''}\"\n                  fill=\"none\"\n                  viewBox=\"0 0 24 24\"\n                  stroke=\"currentColor\"\n                  stroke-width=\"2\"\n                >\n                  <path\n                    stroke-linecap=\"round\"\n                    stroke-linejoin=\"round\"\n                    d=\"M9 5l7 7-7 7\"\n                  />\n                </svg>\n                Advanced Options\n              </button>\n\n              {#if showAdvancedOptions}\n                <div class=\"mt-3 space-y-3 pl-1\" in:fade={{ duration: 150 }}>\n                  <!-- Sharding Strategy -->\n                  <div>\n                    <div class=\"text-xs text-white/50 font-mono mb-2\">\n                      Sharding Strategy:\n                    </div>\n                    <div class=\"flex gap-2\">\n                      <button\n                        onclick={() => {\n                          selectedSharding = \"Pipeline\";\n                          saveLaunchDefaults();\n                        }}\n                        class=\"flex items-center gap-2 py-1.5 px-3 text-xs font-mono border rounded transition-all duration-200 cursor-pointer {selectedSharding ===\n                        'Pipeline'\n                          ? 'bg-transparent text-exo-yellow border-exo-yellow'\n                          : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}\"\n                      >\n                        <span\n                          class=\"w-3 h-3 rounded-full border-2 flex items-center justify-center {selectedSharding ===\n                          'Pipeline'\n                            ? 'border-exo-yellow'\n                            : 'border-exo-medium-gray'}\"\n                        >\n                          {#if selectedSharding === \"Pipeline\"}\n                            <span class=\"w-1.5 h-1.5 rounded-full bg-exo-yellow\"\n                            ></span>\n                          {/if}\n                        </span>\n                        Pipeline\n                      </button>\n                      <button\n                        onclick={() => {\n                          selectedSharding = \"Tensor\";\n                          saveLaunchDefaults();\n                        }}\n                        class=\"flex items-center gap-2 py-1.5 px-3 text-xs font-mono border rounded transition-all duration-200 cursor-pointer {selectedSharding ===\n                        'Tensor'\n                          ? 'bg-transparent text-exo-yellow border-exo-yellow'\n                          : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}\"\n                      >\n                        <span\n                          class=\"w-3 h-3 rounded-full border-2 flex items-center justify-center {selectedSharding ===\n                          'Tensor'\n                            ? 
'border-exo-yellow'\n                            : 'border-exo-medium-gray'}\"\n                        >\n                          {#if selectedSharding === \"Tensor\"}\n                            <span class=\"w-1.5 h-1.5 rounded-full bg-exo-yellow\"\n                            ></span>\n                          {/if}\n                        </span>\n                        Tensor\n                      </button>\n                    </div>\n                  </div>\n\n                  <!-- Interconnect -->\n                  <div>\n                    <div class=\"text-xs text-white/50 font-mono mb-2\">\n                      Interconnect:\n                    </div>\n                    <div class=\"flex gap-2\">\n                      <button\n                        onclick={() => {\n                          selectedInstanceType = \"MlxRing\";\n                          saveLaunchDefaults();\n                        }}\n                        class=\"flex items-center gap-2 py-1.5 px-3 text-xs font-mono border rounded transition-all duration-200 cursor-pointer {selectedInstanceType ===\n                        'MlxRing'\n                          ? 'bg-transparent text-exo-yellow border-exo-yellow'\n                          : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}\"\n                      >\n                        <span\n                          class=\"w-3 h-3 rounded-full border-2 flex items-center justify-center {selectedInstanceType ===\n                          'MlxRing'\n                            ? 'border-exo-yellow'\n                            : 'border-exo-medium-gray'}\"\n                        >\n                          {#if selectedInstanceType === \"MlxRing\"}\n                            <span class=\"w-1.5 h-1.5 rounded-full bg-exo-yellow\"\n                            ></span>\n                          {/if}\n                        </span>\n                        TCP/IP\n                      </button>\n                      <button\n                        onclick={() => {\n                          selectedInstanceType = \"MlxJaccl\";\n                          saveLaunchDefaults();\n                        }}\n                        class=\"flex items-center gap-2 py-1.5 px-3 text-xs font-mono border rounded transition-all duration-200 cursor-pointer {selectedInstanceType ===\n                        'MlxJaccl'\n                          ? 'bg-transparent text-exo-yellow border-exo-yellow'\n                          : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}\"\n                      >\n                        <span\n                          class=\"w-3 h-3 rounded-full border-2 flex items-center justify-center {selectedInstanceType ===\n                          'MlxJaccl'\n                            ? 
'border-exo-yellow'\n                            : 'border-exo-medium-gray'}\"\n                        >\n                          {#if selectedInstanceType === \"MlxJaccl\"}\n                            <span class=\"w-1.5 h-1.5 rounded-full bg-exo-yellow\"\n                            ></span>\n                          {/if}\n                        </span>\n                        RDMA (Fast)\n                      </button>\n                    </div>\n                  </div>\n\n                  <!-- Minimum Devices -->\n                  <div>\n                    <div class=\"text-xs text-white/50 font-mono mb-2\">\n                      Minimum Devices:\n                    </div>\n                    <!-- Discrete slider track with drag support -->\n                    <!-- svelte-ignore a11y_no_static_element_interactions -->\n                    <div\n                      bind:this={sliderTrackElement}\n                      class=\"relative h-16 cursor-pointer select-none px-2 pr-6\"\n                      onmousedown={handleSliderMouseDown}\n                      ontouchstart={handleSliderTouchStart}\n                    >\n                      <!-- Track background -->\n                      <div\n                        class=\"absolute top-6 left-0 right-0 h-2 bg-exo-medium-gray/50 rounded-full\"\n                      ></div>\n                      <!-- Active track (fills up to selected) -->\n                      {#if availableMinNodes > 1}\n                        <div\n                          class=\"absolute top-6 left-0 h-2 bg-white/30 rounded-full transition-all pointer-events-none\"\n                          style=\"width: {((selectedMinNodes - 1) /\n                            (availableMinNodes - 1)) *\n                            100}%\"\n                        ></div>\n                      {/if}\n                      <!-- Dots and labels for each device count -->\n                      {#each Array.from({ length: availableMinNodes }, (_, i) => i + 1) as n}\n                        {@const isValid = validMinNodeCounts().has(n)}\n                        {@const isSelected = selectedMinNodes === n}\n                        {@const position =\n                          availableMinNodes > 1\n                            ? ((n - 1) / (availableMinNodes - 1)) * 100\n                            : 50}\n                        <div\n                          class=\"absolute flex flex-col items-center pointer-events-none\"\n                          style=\"left: {position}%; top: 0; transform: translateX(-50%);\"\n                        >\n                          <span\n                            class=\"rounded-full transition-all {isSelected\n                              ? 'w-6 h-6 bg-exo-yellow shadow-[0_0_10px_rgba(255,215,0,0.6)]'\n                              : isValid\n                                ? 'w-4 h-4 bg-exo-light-gray/70 mt-1'\n                                : 'w-3 h-3 bg-exo-medium-gray/50 mt-1.5'}\"\n                          ></span>\n                          <span\n                            class=\"text-sm font-mono mt-1.5 tabular-nums transition-colors {isSelected\n                              ? 'text-exo-yellow font-bold'\n                              : isValid\n                                ? 
'text-white/70'\n                                : 'text-white/30'}\">{n}</span\n                          >\n                        </div>\n                      {/each}\n                    </div>\n                  </div>\n                </div>\n              {/if}\n            </div>\n\n            <!-- Selected Model Preview -->\n            <div class=\"space-y-3\">\n              {#if models.length === 0}\n                <div class=\"text-center py-8\">\n                  <div\n                    class=\"text-xs text-white/70 font-mono tracking-wider uppercase\"\n                  >\n                    Loading models...\n                  </div>\n                </div>\n              {:else if loadingPreviews}\n                <div class=\"text-center py-8\">\n                  <div\n                    class=\"text-xs text-exo-yellow font-mono tracking-wider uppercase animate-pulse\"\n                  >\n                    Loading preview...\n                  </div>\n                </div>\n              {:else}\n                {@const selectedModel = models.find(\n                  (m) => m.id === selectedModelId,\n                )}\n                {@const allPreviews = filteredPreviews()}\n                {#if selectedModel && allPreviews.length > 0}\n                  {@const tags = modelTags()[selectedModel.id] || []}\n                  <div class=\"space-y-3\">\n                    {#each allPreviews as apiPreview, i}\n                      {@const downloadStatus = getModelDownloadStatus(\n                        selectedModel.id,\n                        apiPreview.memory_delta_by_node\n                          ? Object.keys(apiPreview.memory_delta_by_node)\n                          : undefined,\n                      )}\n                      <div\n                        role=\"group\"\n                        onmouseenter={() => {\n                          if (apiPreview.memory_delta_by_node) {\n                            hoveredPreviewNodes = new Set(\n                              Object.entries(apiPreview.memory_delta_by_node)\n                                .filter(([, delta]) => (delta ?? 0) > 0)\n                                .map(([nodeId]) => nodeId),\n                            );\n                          }\n                        }}\n                        onmouseleave={() => (hoveredPreviewNodes = new Set())}\n                      >\n                        <ModelCard\n                          model={selectedModel}\n                          isLaunching={launchingModelId === selectedModel.id}\n                          {downloadStatus}\n                          nodes={data?.nodes ?? 
{}}\n                          sharding={apiPreview.sharding}\n                          runtime={apiPreview.instance_meta}\n                          onLaunch={() =>\n                            launchInstance(selectedModel.id, apiPreview)}\n                          {tags}\n                          {apiPreview}\n                          modelIdOverride={apiPreview.model_id}\n                        />\n                      </div>\n                    {/each}\n                  </div>\n                {:else if selectedModel}\n                  <div class=\"text-center py-4\">\n                    <div class=\"text-xs text-white/50 font-mono\">\n                      No valid configurations for current settings\n                    </div>\n                  </div>\n                {/if}\n              {/if}\n            </div>\n          </div>\n        {/snippet}\n      </div>\n    {:else}\n      <!-- CHAT STATE: Chat + Mini-Map -->\n      <div class=\"flex-1 flex overflow-hidden\">\n        <!-- Chat Area -->\n        <div\n          class=\"flex-1 flex flex-col min-w-0 overflow-hidden\"\n          in:fade={{ duration: 300, delay: 100 }}\n        >\n          {#if chatLaunchState !== \"idle\" && chatLaunchState !== \"ready\"}\n            <!-- Model launching/downloading/loading: show progress -->\n            <div class=\"flex-1 flex items-center justify-center px-8 py-6\">\n              <div class=\"flex flex-col items-center gap-6 max-w-md w-full\">\n                <!-- Model name -->\n                {#if pendingChatModelId}\n                  <p class=\"text-sm text-white font-mono tracking-wide\">\n                    {pendingChatModelId.split(\"/\").pop()?.replace(/-/g, \" \") ||\n                      pendingChatModelId}\n                  </p>\n                {/if}\n\n                {#if chatLaunchState === \"launching\"}\n                  <div class=\"flex flex-col items-center gap-3\">\n                    <div\n                      class=\"w-8 h-8 border-2 border-exo-yellow/30 border-t-exo-yellow rounded-full animate-spin\"\n                    ></div>\n                    <p\n                      class=\"text-xs text-exo-light-gray font-mono uppercase tracking-wider\"\n                    >\n                      Preparing to launch&hellip;\n                    </p>\n                  </div>\n                {:else if chatLaunchState === \"downloading\"}\n                  <div class=\"w-full flex flex-col gap-3\">\n                    <div\n                      class=\"flex items-center justify-between text-xs font-mono\"\n                    >\n                      <span class=\"text-exo-yellow uppercase tracking-wider\"\n                        >Downloading</span\n                      >\n                      {#if chatLaunchDownload}\n                        <span class=\"text-exo-light-gray tabular-nums\">\n                          {chatLaunchDownload.percentage.toFixed(1)}%\n                        </span>\n                      {/if}\n                    </div>\n                    <div\n                      class=\"w-full h-2 bg-exo-dark-gray rounded-full overflow-hidden border border-exo-medium-gray/30\"\n                    >\n                      <div\n                        class=\"h-full bg-gradient-to-r from-exo-yellow/80 to-exo-yellow rounded-full transition-all duration-300\"\n                        style=\"width: {chatLaunchDownload?.percentage ?? 
0}%\"\n                      ></div>\n                    </div>\n                    {#if chatLaunchDownload}\n                      <div\n                        class=\"flex justify-between text-[10px] text-exo-light-gray/60 font-mono\"\n                      >\n                        <span\n                          >{formatBytes(chatLaunchDownload.downloadedBytes)} / {formatBytes(\n                            chatLaunchDownload.totalBytes,\n                          )}</span\n                        >\n                        <span>\n                          {#if chatLaunchDownload.speed > 0}\n                            {formatBytes(chatLaunchDownload.speed)}/s\n                          {/if}\n                          {#if chatLaunchDownload.etaMs > 0}\n                            &middot; {formatEta(chatLaunchDownload.etaMs)}\n                          {/if}\n                        </span>\n                      </div>\n                    {/if}\n                  </div>\n                {:else if chatLaunchState === \"loading\"}\n                  <div class=\"w-full flex flex-col gap-3\">\n                    <div\n                      class=\"flex items-center justify-between text-xs font-mono\"\n                    >\n                      <span class=\"text-exo-yellow uppercase tracking-wider\"\n                        >Loading model</span\n                      >\n                      {#if chatLaunchLoadProgress}\n                        <span class=\"text-exo-light-gray tabular-nums\">\n                          {chatLaunchLoadProgress.layersLoaded}/{chatLaunchLoadProgress.totalLayers}\n                          layers\n                        </span>\n                      {/if}\n                    </div>\n                    <div\n                      class=\"w-full h-2 bg-exo-dark-gray rounded-full overflow-hidden border border-exo-medium-gray/30\"\n                    >\n                      <div\n                        class=\"h-full bg-gradient-to-r from-exo-yellow/80 to-exo-yellow rounded-full transition-all duration-300\"\n                        style=\"width: {chatLaunchLoadProgress?.percentage ??\n                          0}%\"\n                      ></div>\n                    </div>\n                  </div>\n                {/if}\n              </div>\n            </div>\n            <div\n              class=\"flex-shrink-0 px-8 pb-6 pt-4 bg-gradient-to-t from-exo-black via-exo-black to-transparent\"\n            >\n              <div class=\"max-w-7xl mx-auto\">\n                <ChatForm\n                  placeholder=\"Ask anything\"\n                  showModelSelector={true}\n                  modelTasks={modelTasks()}\n                  modelCapabilities={modelCapabilities()}\n                  onAutoSend={handleChatSend}\n                  onOpenModelPicker={openChatModelPicker}\n                />\n              </div>\n            </div>\n          {:else if messages().length > 0 || chatLaunchState === \"ready\"}\n            <!-- Normal chat: show messages -->\n            <div\n              class=\"flex-1 overflow-y-auto px-8 py-6\"\n              bind:this={chatScrollRef}\n              role=\"log\"\n              aria-live=\"polite\"\n              aria-label=\"Chat messages\"\n            >\n              <div class=\"max-w-7xl mx-auto\">\n                <ChatMessages scrollParent={chatScrollRef} />\n                {#if chatLaunchState === \"ready\" && selectedChatCategory}\n                  {@const prompts =\n                    
categorySuggestedPrompts[selectedChatCategory] ??\n                    categorySuggestedPrompts.auto}\n                  <div\n                    class=\"flex flex-col items-center gap-4 mt-12\"\n                    in:fade={{ duration: 300 }}\n                  >\n                    <p\n                      class=\"text-xs text-exo-light-gray/60 font-mono uppercase tracking-wider\"\n                    >\n                      Try asking\n                    </p>\n                    <div class=\"grid grid-cols-2 gap-2 max-w-lg w-full\">\n                      {#each prompts as prompt}\n                        <button\n                          type=\"button\"\n                          onclick={() => {\n                            chatLaunchState = \"idle\";\n                            selectedChatCategory = null;\n                            sendMessage(prompt);\n                          }}\n                          class=\"text-left px-3 py-2.5 text-xs text-exo-light-gray hover:text-white font-mono rounded-lg border border-exo-medium-gray/30 hover:border-exo-yellow/30 bg-exo-dark-gray/30 hover:bg-exo-dark-gray/60 transition-all duration-200 cursor-pointer\"\n                        >\n                          {prompt}\n                        </button>\n                      {/each}\n                    </div>\n                  </div>\n                {/if}\n              </div>\n            </div>\n            <div\n              class=\"flex-shrink-0 px-8 pb-6 pt-4 bg-gradient-to-t from-exo-black via-exo-black to-transparent\"\n            >\n              <div class=\"max-w-7xl mx-auto\">\n                <ChatForm\n                  placeholder=\"Ask anything\"\n                  showModelSelector={true}\n                  modelTasks={modelTasks()}\n                  modelCapabilities={modelCapabilities()}\n                  onAutoSend={handleChatSend}\n                  onOpenModelPicker={openChatModelPicker}\n                />\n              </div>\n            </div>\n          {:else}\n            <!-- No running instance, no messages: show model selector -->\n            <div\n              class=\"flex-1 overflow-y-auto flex items-center justify-center px-8 py-6\"\n            >\n              <ChatModelSelector\n                models={models.map((m) => ({\n                  id: m.id,\n                  name: m.name ?? \"\",\n                  base_model: m.base_model ?? \"\",\n                  storage_size_megabytes: m.storage_size_megabytes ?? 0,\n                  capabilities: m.capabilities ?? [],\n                  family: m.family ?? \"\",\n                  quantization: m.quantization ?? \"\",\n                }))}\n                clusterLabel={chatClusterLabel}\n                totalMemoryGB={availableMemoryGB()}\n                onSelect={handleChatModelSelect}\n                onAddModel={handleChatAddModel}\n              />\n            </div>\n            <div\n              class=\"flex-shrink-0 px-8 pb-6 pt-4 bg-gradient-to-t from-exo-black via-exo-black to-transparent\"\n            >\n              <div class=\"max-w-7xl mx-auto\">\n                <ChatForm\n                  placeholder=\"Ask anything — we'll pick the best model automatically\"\n                  showModelSelector={!!bestRunningModelId}\n                  modelDisplayOverride={bestRunningModelId ?? 
undefined}\n                  modelTasks={modelTasks()}\n                  modelCapabilities={modelCapabilities()}\n                  onAutoSend={handleAutoSend}\n                  onOpenModelPicker={openChatModelPicker}\n                />\n              </div>\n            </div>\n          {/if}\n        </div>\n\n        <!-- Right: Mini-Map Sidebar - Desktop only -->\n        {#if minimized}\n          <aside\n            class=\"hidden md:flex w-80 border-l border-exo-yellow/20 bg-exo-dark-gray flex-col flex-shrink-0 overflow-y-auto\"\n            in:fly={{ x: 100, duration: 400, easing: cubicInOut }}\n            aria-label=\"Cluster topology\"\n          >\n            <!-- Topology Section - clickable to go back to main view -->\n            <button\n              class=\"p-4 border-b border-exo-medium-gray/30 w-full text-left cursor-pointer hover:bg-exo-medium-gray/10 transition-colors\"\n              onclick={handleGoHome}\n              title=\"Click to return to main topology view\"\n            >\n              <div class=\"flex items-center justify-between mb-3\">\n                <div\n                  class=\"text-xs text-exo-yellow tracking-[0.2em] uppercase flex items-center gap-2\"\n                >\n                  <span\n                    class=\"w-1.5 h-1.5 bg-exo-yellow rounded-full status-pulse\"\n                  ></span>\n                  TOPOLOGY\n                </div>\n                <span class=\"text-xs text-white/70 tabular-nums\"\n                  >{nodeCount} {nodeCount === 1 ? \"NODE\" : \"NODES\"}</span\n                >\n              </div>\n\n              <div\n                class=\"relative aspect-square bg-exo-dark-gray rounded-lg overflow-hidden pointer-events-none\"\n              >\n                <TopologyGraph\n                  highlightedNodes={highlightedNodes()}\n                  filteredNodes={nodeFilter}\n                />\n\n                {@render clusterWarningsCompact()}\n              </div>\n            </button>\n\n            <!-- Instances Section (only shown when instances exist) -->\n            {#if instanceCount > 0}\n              <div class=\"p-4 flex-1\">\n                <!-- Panel Header -->\n                <div class=\"flex items-center gap-2 mb-4\">\n                  <div\n                    class=\"w-2 h-2 bg-exo-yellow rounded-full shadow-[0_0_8px_rgba(255,215,0,0.6)] animate-pulse\"\n                  ></div>\n                  <h3\n                    class=\"text-xs text-exo-yellow font-mono tracking-[0.2em] uppercase\"\n                  >\n                    Instances\n                  </h3>\n                  <div\n                    class=\"flex-1 h-px bg-gradient-to-r from-exo-yellow/30 to-transparent\"\n                  ></div>\n                </div>\n                <div\n                  class=\"space-y-3 max-h-72 xl:max-h-96 overflow-y-auto overflow-x-hidden py-px pr-1\"\n                >\n                  {#each Object.entries(instanceData) as [id, instance]}\n                    {@const downloadInfo = getInstanceDownloadStatus(\n                      id,\n                      instance,\n                    )}\n                    {@const statusText = downloadInfo.statusText}\n                    {@const isDownloading = downloadInfo.isDownloading}\n                    {@const isFailed = statusText === \"FAILED\"}\n                    {@const isLoading = statusText === \"LOADING\"}\n                    {@const isWarmingUp =\n                      statusText === \"WARMING 
UP\" || statusText === \"WAITING\"}\n                    {@const isReady =\n                      statusText === \"READY\" || statusText === \"LOADED\"}\n                    {@const isRunning = statusText === \"RUNNING\"}\n                    <!-- Instance Card -->\n                    {@const instanceModelId = getInstanceModelId(instance)}\n                    {@const instanceInfo = getInstanceInfo(instance)}\n                    {@const instanceConnections =\n                      getInstanceConnections(instance)}\n                    <div\n                      class=\"relative group cursor-pointer\"\n                      role=\"button\"\n                      tabindex=\"0\"\n                      onmouseenter={() => (hoveredInstanceId = id)}\n                      onmouseleave={() => (hoveredInstanceId = null)}\n                      onclick={() => {\n                        if (\n                          instanceModelId &&\n                          instanceModelId !== \"Unknown\" &&\n                          instanceModelId !== \"Unknown Model\"\n                        ) {\n                          userForcedIdle = false;\n                          setSelectedChatModel(instanceModelId);\n                        }\n                      }}\n                      onkeydown={(e) => {\n                        if (e.key === \"Enter\" || e.key === \" \") {\n                          if (\n                            instanceModelId &&\n                            instanceModelId !== \"Unknown\" &&\n                            instanceModelId !== \"Unknown Model\"\n                          ) {\n                            setSelectedChatModel(instanceModelId);\n                          }\n                        }\n                      }}\n                    >\n                      <!-- Corner accents -->\n                      <div\n                        class=\"absolute -top-px -left-px w-2 h-2 border-l border-t {isDownloading\n                          ? 'border-blue-500/50'\n                          : isFailed\n                            ? 'border-red-500/50'\n                            : isLoading\n                              ? 'border-yellow-500/50'\n                              : isReady\n                                ? 'border-green-500/50'\n                                : 'border-teal-500/50'}\"\n                      ></div>\n                      <div\n                        class=\"absolute -top-px -right-px w-2 h-2 border-r border-t {isDownloading\n                          ? 'border-blue-500/50'\n                          : isFailed\n                            ? 'border-red-500/50'\n                            : isLoading\n                              ? 'border-yellow-500/50'\n                              : isReady\n                                ? 'border-green-500/50'\n                                : 'border-teal-500/50'}\"\n                      ></div>\n                      <div\n                        class=\"absolute -bottom-px -left-px w-2 h-2 border-l border-b {isDownloading\n                          ? 'border-blue-500/50'\n                          : isFailed\n                            ? 'border-red-500/50'\n                            : isLoading\n                              ? 'border-yellow-500/50'\n                              : isReady\n                                ? 
'border-green-500/50'\n                                : 'border-teal-500/50'}\"\n                      ></div>\n                      <div\n                        class=\"absolute -bottom-px -right-px w-2 h-2 border-r border-b {isDownloading\n                          ? 'border-blue-500/50'\n                          : isFailed\n                            ? 'border-red-500/50'\n                            : isLoading\n                              ? 'border-yellow-500/50'\n                              : isReady\n                                ? 'border-green-500/50'\n                                : 'border-teal-500/50'}\"\n                      ></div>\n\n                      <div\n                        class=\"bg-exo-dark-gray/60 border border-l-2 {isDownloading\n                          ? 'border-blue-500/30 border-l-blue-400'\n                          : isFailed\n                            ? 'border-red-500/30 border-l-red-400'\n                            : isLoading\n                              ? 'border-exo-yellow/30 border-l-yellow-400'\n                              : isReady\n                                ? 'border-green-500/30 border-l-green-400'\n                                : 'border-teal-500/30 border-l-teal-400'} p-3\"\n                      >\n                        <div class=\"flex justify-between items-start mb-2 pl-2\">\n                          <div class=\"flex items-center gap-2\">\n                            <div\n                              class=\"w-1.5 h-1.5 {isDownloading\n                                ? 'bg-blue-400 animate-pulse'\n                                : isFailed\n                                  ? 'bg-red-400'\n                                  : isLoading\n                                    ? 'bg-yellow-400 animate-pulse'\n                                    : isReady\n                                      ? 
'bg-green-400'\n                                      : 'bg-teal-400'} rounded-full shadow-[0_0_6px_currentColor]\"\n                            ></div>\n                            <span\n                              class=\"text-exo-light-gray font-mono text-sm tracking-wider\"\n                              >{id.slice(0, 8).toUpperCase()}</span\n                            >\n                          </div>\n                          <button\n                            onclick={() => deleteInstance(id)}\n                            class=\"text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer\"\n                          >\n                            DELETE\n                          </button>\n                        </div>\n                        <div class=\"pl-2\">\n                          <div\n                            class=\"text-exo-yellow text-xs font-mono tracking-wide truncate\"\n                          >\n                            {getInstanceModelId(instance)}\n                          </div>\n                          <div\n                            class=\"flex items-center gap-2 text-white/60 text-xs font-mono\"\n                          >\n                            <span\n                              >{instanceInfo.sharding} &middot; {instanceInfo.instanceType}</span\n                            >\n                            <span\n                              class=\"px-1.5 py-0.5 text-[10px] tracking-wider uppercase rounded transition-all duration-300 {isDownloading\n                                ? 'bg-blue-500/15 text-blue-400'\n                                : isFailed\n                                  ? 'bg-red-500/15 text-red-400'\n                                  : isLoading\n                                    ? 'bg-yellow-500/15 text-yellow-400'\n                                    : isReady\n                                      ? 
'bg-green-500/15 text-green-400'\n                                      : 'bg-teal-500/15 text-teal-400'}\"\n                            >\n                              {statusText}\n                            </span>\n                          </div>\n                          {#if instanceModelId && instanceModelId !== \"Unknown\" && instanceModelId !== \"Unknown Model\"}\n                            <a\n                              class=\"inline-flex items-center gap-1 text-[11px] text-white/60 hover:text-exo-yellow transition-colors mt-1\"\n                              href={`https://huggingface.co/${instanceModelId}`}\n                              target=\"_blank\"\n                              rel=\"noreferrer noopener\"\n                              aria-label=\"View model on Hugging Face\"\n                            >\n                              <span>Hugging Face</span>\n                              <svg\n                                class=\"w-3.5 h-3.5\"\n                                viewBox=\"0 0 24 24\"\n                                fill=\"none\"\n                                stroke=\"currentColor\"\n                                stroke-width=\"2\"\n                                stroke-linecap=\"round\"\n                                stroke-linejoin=\"round\"\n                              >\n                                <path d=\"M14 3h7v7\" />\n                                <path d=\"M10 14l11-11\" />\n                                <path\n                                  d=\"M21 14v6a1 1 0 0 1-1 1h-16a1 1 0 0 1-1-1v-16a1 1 0 0 1 1-1h6\"\n                                />\n                              </svg>\n                            </a>\n                          {/if}\n                          {#if instanceInfo.nodeNames.length > 0}\n                            <div class=\"text-white/60 text-xs font-mono\">\n                              {instanceInfo.nodeNames.join(\", \")}\n                            </div>\n                          {/if}\n                          {#if debugEnabled && instanceConnections.length > 0}\n                            <div class=\"mt-2 space-y-1\">\n                              {#each instanceConnections as conn}\n                                <div\n                                  class=\"text-[11px] leading-snug font-mono text-white/70\"\n                                >\n                                  <span\n                                    >{conn.from} -> {conn.to}: {conn.ip}</span\n                                  >\n                                  <span\n                                    class={conn.missingIface\n                                      ? 
\"text-red-400\"\n                                      : \"text-white/60\"}\n                                  >\n                                    ({conn.ifaceLabel})</span\n                                  >\n                                </div>\n                              {/each}\n                            </div>\n                          {/if}\n\n                          <!-- Download Progress -->\n                          {#if downloadInfo.isDownloading && downloadInfo.progress}\n                            <div class=\"mt-2 space-y-1\">\n                              <div\n                                class=\"flex justify-between text-xs font-mono\"\n                              >\n                                <span class=\"text-blue-400\"\n                                  >{downloadInfo.progress.percentage.toFixed(\n                                    1,\n                                  )}%</span\n                                >\n                                <span class=\"text-exo-light-gray\"\n                                  >{formatBytes(\n                                    downloadInfo.progress.downloadedBytes,\n                                  )}/{formatBytes(\n                                    downloadInfo.progress.totalBytes,\n                                  )}</span\n                                >\n                              </div>\n                              <div\n                                class=\"relative h-1.5 bg-exo-black/60 rounded-sm overflow-hidden\"\n                              >\n                                <div\n                                  class=\"absolute inset-y-0 left-0 bg-gradient-to-r from-blue-500 to-blue-400 transition-all duration-300\"\n                                  style=\"width: {downloadInfo.progress\n                                    .percentage}%\"\n                                ></div>\n                              </div>\n                              <div\n                                class=\"flex justify-between text-xs font-mono text-exo-light-gray\"\n                              >\n                                <span\n                                  >{formatSpeed(\n                                    downloadInfo.progress.speed,\n                                  )}</span\n                                >\n                                <span\n                                  >ETA: {formatEta(\n                                    downloadInfo.progress.etaMs,\n                                  )}</span\n                                >\n                                <span\n                                  >{downloadInfo.progress\n                                    .completedFiles}/{downloadInfo.progress\n                                    .totalFiles} files</span\n                                >\n                              </div>\n                            </div>\n                            {#if downloadInfo.perNode.length > 0}\n                              <div\n                                class=\"mt-2 space-y-2 max-h-48 overflow-y-auto pr-1\"\n                              >\n                                {#each downloadInfo.perNode.filter((n) => n.status === \"downloading\" && n.progress) as nodeProg}\n                                  {@const nodePercent = Math.min(\n                                    100,\n                                    Math.max(0, nodeProg.percentage),\n                                  )}\n                          
        {@const isExpanded =\n                                    instanceDownloadExpandedNodes.has(\n                                      nodeProg.nodeId,\n                                    )}\n                                  <div\n                                    class=\"rounded border border-exo-medium-gray/40 bg-exo-black/30 p-2\"\n                                  >\n                                    <button\n                                      type=\"button\"\n                                      class=\"w-full text-left space-y-1.5\"\n                                      onclick={() =>\n                                        toggleInstanceDownloadDetails(\n                                          nodeProg.nodeId,\n                                        )}\n                                    >\n                                      <div\n                                        class=\"flex items-center justify-between text-[11px] font-mono text-exo-light-gray\"\n                                      >\n                                        <span\n                                          class=\"text-white/80 truncate pr-2\"\n                                          >{nodeProg.nodeName}</span\n                                        >\n                                        <span\n                                          class=\"flex items-center gap-1 text-blue-300\"\n                                        >\n                                          {nodePercent.toFixed(1)}%\n                                          <svg\n                                            class=\"w-3 h-3 text-exo-light-gray\"\n                                            viewBox=\"0 0 20 20\"\n                                            fill=\"none\"\n                                            stroke=\"currentColor\"\n                                            stroke-width=\"2\"\n                                          >\n                                            <path\n                                              d=\"M6 8l4 4 4-4\"\n                                              class={isExpanded\n                                                ? 
\"transform rotate-180 origin-center transition-transform duration-150\"\n                                                : \"transition-transform duration-150\"}\n                                            ></path>\n                                          </svg>\n                                        </span>\n                                      </div>\n                                      <div\n                                        class=\"relative h-1.5 bg-exo-black/60 rounded-sm overflow-hidden\"\n                                      >\n                                        <div\n                                          class=\"absolute inset-y-0 left-0 bg-gradient-to-r from-blue-500 to-blue-400 transition-all duration-300\"\n                                          style=\"width: {nodePercent.toFixed(\n                                            1,\n                                          )}%\"\n                                        ></div>\n                                      </div>\n                                      <div\n                                        class=\"flex items-center justify-between text-[11px] font-mono text-exo-light-gray\"\n                                      >\n                                        <span\n                                          >{formatBytes(\n                                            nodeProg.progress\n                                              ?.downloadedBytes ?? 0,\n                                          )} / {formatBytes(\n                                            nodeProg.progress?.totalBytes ?? 0,\n                                          )}</span\n                                        >\n                                        <span\n                                          >{formatSpeed(\n                                            nodeProg.progress?.speed ?? 0,\n                                          )} • ETA {formatEta(\n                                            nodeProg.progress?.etaMs ?? 0,\n                                          )}</span\n                                        >\n                                      </div>\n                                    </button>\n\n                                    {#if isExpanded}\n                                      <div class=\"mt-2 space-y-1.5\">\n                                        {#if nodeProg.progress?.files ?? [].length === 0}\n                                          <div\n                                            class=\"text-[11px] font-mono text-exo-light-gray/70\"\n                                          >\n                                            No file details reported.\n                                          </div>\n                                        {:else}\n                                          {#each nodeProg.progress?.files ?? [] as f}\n                                            {@const filePercent = Math.min(\n                                              100,\n                                              Math.max(0, f.percentage ?? 
0),\n                                            )}\n                                            {@const isFileComplete =\n                                              filePercent >= 100}\n                                            <div\n                                              class=\"rounded border border-exo-medium-gray/30 bg-exo-black/40 p-2\"\n                                            >\n                                              <div\n                                                class=\"flex items-center justify-between text-[10px] font-mono text-exo-light-gray/90\"\n                                              >\n                                                <span class=\"truncate pr-2\"\n                                                  >{f.name}</span\n                                                >\n                                                <span\n                                                  class={isFileComplete\n                                                    ? \"text-green-400\"\n                                                    : \"text-white/80\"}\n                                                  >{filePercent.toFixed(\n                                                    1,\n                                                  )}%</span\n                                                >\n                                              </div>\n                                              <div\n                                                class=\"relative h-1 bg-exo-black/60 rounded-sm overflow-hidden mt-1\"\n                                              >\n                                                <div\n                                                  class=\"absolute inset-y-0 left-0 bg-gradient-to-r {isFileComplete\n                                                    ? 
'from-green-500 to-green-400'\n                                                    : 'from-exo-yellow to-exo-yellow/70'} transition-all duration-300\"\n                                                  style=\"width: {filePercent.toFixed(\n                                                    1,\n                                                  )}%\"\n                                                ></div>\n                                              </div>\n                                              <div\n                                                class=\"flex items-center justify-between text-[10px] text-exo-light-gray/70 mt-0.5\"\n                                              >\n                                                <span\n                                                  >{formatBytes(\n                                                    f.downloadedBytes,\n                                                  )} / {formatBytes(\n                                                    f.totalBytes,\n                                                  )}</span\n                                                >\n                                                <span\n                                                  >{formatSpeed(f.speed)} • ETA\n                                                  {formatEta(f.etaMs)}</span\n                                                >\n                                              </div>\n                                            </div>\n                                          {/each}\n                                        {/if}\n                                      </div>\n                                    {/if}\n                                  </div>\n                                {/each}\n                              </div>\n                            {/if}\n                            <div class=\"mt-2 space-y-1\">\n                              <div\n                                class=\"text-xs text-blue-400 font-mono tracking-wider\"\n                              >\n                                DOWNLOADING\n                              </div>\n                              <p\n                                class=\"text-[11px] text-white/50 leading-relaxed\"\n                              >\n                                Downloading model files. 
The model runs on your\n                                devices, so it needs to be downloaded before you\n                                can chat.\n                              </p>\n                            </div>\n                          {:else}\n                            <div class=\"mt-1 space-y-1\">\n                              <div\n                                class=\"text-xs {getStatusColor(\n                                  downloadInfo.statusText,\n                                )} font-mono tracking-wider\"\n                              >\n                                {downloadInfo.statusText}\n                              </div>\n                              {#if isLoading}\n                                {@const loadStatus =\n                                  deriveInstanceStatus(instance)}\n                                {#if loadStatus.totalLayers && loadStatus.totalLayers > 0}\n                                  <div class=\"mt-1 space-y-1\">\n                                    <div\n                                      class=\"flex justify-between text-xs font-mono\"\n                                    >\n                                      <span class=\"text-yellow-400\"\n                                        >{(\n                                          ((loadStatus.layersLoaded ?? 0) /\n                                            loadStatus.totalLayers) *\n                                          100\n                                        ).toFixed(0)}%</span\n                                      >\n                                      <span class=\"text-exo-light-gray\"\n                                        >{loadStatus.layersLoaded ?? 0} / {loadStatus.totalLayers}\n                                        layers</span\n                                      >\n                                    </div>\n                                    <div\n                                      class=\"relative h-1.5 bg-exo-black/60 rounded-sm overflow-hidden\"\n                                    >\n                                      <div\n                                        class=\"absolute inset-y-0 left-0 bg-gradient-to-r from-yellow-500 to-yellow-400 transition-all duration-300\"\n                                        style=\"width: {((loadStatus.layersLoaded ??\n                                          0) /\n                                          loadStatus.totalLayers) *\n                                          100}%\"\n                                      ></div>\n                                    </div>\n                                  </div>\n                                {:else}\n                                  <p\n                                    class=\"text-[11px] text-white/50 leading-relaxed\"\n                                  >\n                                    Loading model into memory...\n                                  </p>\n                                {/if}\n                              {:else if isWarmingUp}\n                                <p\n                                  class=\"text-[11px] text-white/50 leading-relaxed\"\n                                >\n                                  Warming up...\n                                </p>\n                              {:else if isReady || isRunning}\n                                <p\n                                  class=\"text-[11px] text-green-400/70 leading-relaxed\"\n                                >\n                        
          Ready to chat!\n                                </p>\n                              {/if}\n                            </div>\n                            {#if downloadInfo.isFailed && downloadInfo.errorMessage}\n                              <div\n                                class=\"text-xs text-red-400/80 font-mono mt-1 break-words\"\n                              >\n                                {downloadInfo.errorMessage}\n                              </div>\n                            {/if}\n                          {/if}\n                        </div>\n                      </div>\n                    </div>\n                  {/each}\n                </div>\n              </div>\n            {/if}\n          </aside>\n        {/if}\n      </div>\n    {/if}\n  </main>\n</div>\n\n{#if !showOnboarding}\n  <ModelPickerModal\n    isOpen={isModelPickerOpen}\n    {models}\n    {selectedModelId}\n    favorites={favoritesSet}\n    {recentModelIds}\n    hasRecents={showRecentsTab}\n    existingModelIds={new Set(models.map((m) => m.id))}\n    canModelFit={(modelId) => {\n      const model = models.find((m) => m.id === modelId);\n      return model ? hasEnoughMemory(model) : false;\n    }}\n    getModelFitStatus={(modelId): ModelMemoryFitStatus => {\n      const model = models.find((m) => m.id === modelId);\n      return model ? getModelMemoryFitStatus(model) : \"too_large\";\n    }}\n    onSelect={(modelId) => {\n      if (modelPickerContext === \"chat\") {\n        handleChatPickerSelect(modelId);\n      } else {\n        handleModelPickerSelect(modelId);\n      }\n    }}\n    onClose={() => (isModelPickerOpen = false)}\n    onToggleFavorite={toggleFavorite}\n    onAddModel={addModelFromPicker}\n    onDeleteModel={deleteCustomModel}\n    totalMemoryGB={clusterMemory().total / (1024 * 1024 * 1024)}\n    usedMemoryGB={clusterMemory().used / (1024 * 1024 * 1024)}\n    {downloadsData}\n    topologyNodes={data?.nodes}\n    instanceStatuses={modelInstanceStatuses}\n  />\n{/if}\n"
  },
  {
    "path": "dashboard/src/routes/downloads/+page.svelte",
    "content": "<script lang=\"ts\">\n  import { onMount } from \"svelte\";\n  import { fade, fly } from \"svelte/transition\";\n  import { cubicOut } from \"svelte/easing\";\n  import {\n    topologyData,\n    downloads,\n    nodeDisk,\n    refreshState,\n    lastUpdate as lastUpdateStore,\n    startDownload,\n    deleteDownload,\n  } from \"$lib/stores/app.svelte\";\n  import {\n    getDownloadTag,\n    extractModelIdFromDownload,\n    extractShardMetadata,\n  } from \"$lib/utils/downloads\";\n  import HeaderNav from \"$lib/components/HeaderNav.svelte\";\n\n  type CellStatus =\n    | { kind: \"completed\"; totalBytes: number; modelDirectory?: string }\n    | {\n        kind: \"downloading\";\n        percentage: number;\n        downloadedBytes: number;\n        totalBytes: number;\n        speed: number;\n        etaMs: number;\n        modelDirectory?: string;\n      }\n    | {\n        kind: \"pending\";\n        downloaded: number;\n        total: number;\n        modelDirectory?: string;\n      }\n    | { kind: \"failed\"; modelDirectory?: string }\n    | { kind: \"not_present\" };\n\n  type ModelCardInfo = {\n    family: string;\n    quantization: string;\n    baseModel: string;\n    capabilities: string[];\n    storageSize: number;\n    nLayers: number;\n    supportsTensor: boolean;\n  };\n\n  type ModelRow = {\n    modelId: string;\n    prettyName: string | null;\n    cells: Record<string, CellStatus>;\n    shardMetadata: Record<string, unknown> | null;\n    modelCard: ModelCardInfo | null;\n  };\n\n  type NodeColumn = {\n    nodeId: string;\n    label: string;\n    diskAvailable?: number;\n    diskTotal?: number;\n  };\n\n  const data = $derived(topologyData());\n  const downloadsData = $derived(downloads());\n  const nodeDiskData = $derived(nodeDisk());\n\n  function getNodeLabel(nodeId: string): string {\n    const node = data?.nodes?.[nodeId];\n    if (!node) return nodeId.slice(0, 8);\n    return (\n      node.friendly_name || node.system_info?.model_id || nodeId.slice(0, 8)\n    );\n  }\n\n  function getBytes(value: unknown): number {\n    if (typeof value === \"number\") return value;\n    if (value && typeof value === \"object\") {\n      const v = value as Record<string, unknown>;\n      if (typeof v.inBytes === \"number\") return v.inBytes;\n    }\n    return 0;\n  }\n\n  function formatBytes(bytes: number): string {\n    if (!bytes || bytes <= 0) return \"0B\";\n    const units = [\"B\", \"KB\", \"MB\", \"GB\", \"TB\"];\n    const i = Math.min(\n      Math.floor(Math.log(bytes) / Math.log(1024)),\n      units.length - 1,\n    );\n    const val = bytes / Math.pow(1024, i);\n    return `${val.toFixed(val >= 10 ? 0 : 1)}${units[i]}`;\n  }\n\n  function formatEta(ms: number): string {\n    if (!ms || ms <= 0) return \"--\";\n    const totalSeconds = Math.round(ms / 1000);\n    const s = totalSeconds % 60;\n    const m = Math.floor(totalSeconds / 60) % 60;\n    const h = Math.floor(totalSeconds / 3600);\n    if (h > 0) return `${h}h ${m}m`;\n    if (m > 0) return `${m}m ${s}s`;\n    return `${s}s`;\n  }\n\n  function formatSpeed(bytesPerSecond: number): string {\n    if (!bytesPerSecond || bytesPerSecond <= 0) return \"--\";\n    const units = [\"B/s\", \"KB/s\", \"MB/s\", \"GB/s\"];\n    const i = Math.min(\n      Math.floor(Math.log(bytesPerSecond) / Math.log(1024)),\n      units.length - 1,\n    );\n    const val = bytesPerSecond / Math.pow(1024, i);\n    return `${val.toFixed(val >= 10 ? 
0 : 1)}${units[i]}`;\n  }\n\n  function clampPercent(value: number | undefined): number {\n    if (!Number.isFinite(value)) return 0;\n    return Math.min(100, Math.max(0, value as number));\n  }\n\n  const CELL_PRIORITY: Record<CellStatus[\"kind\"], number> = {\n    completed: 4,\n    downloading: 3,\n    pending: 2,\n    failed: 1,\n    not_present: 0,\n  };\n\n  function shouldUpgradeCell(\n    existing: CellStatus,\n    candidate: CellStatus,\n  ): boolean {\n    return CELL_PRIORITY[candidate.kind] > CELL_PRIORITY[existing.kind];\n  }\n\n  function extractModelCard(payload: Record<string, unknown>): {\n    prettyName: string | null;\n    card: ModelCardInfo | null;\n  } {\n    const shardMetadata = payload.shard_metadata ?? payload.shardMetadata;\n    if (!shardMetadata || typeof shardMetadata !== \"object\")\n      return { prettyName: null, card: null };\n    const shardObj = shardMetadata as Record<string, unknown>;\n    const shardKeys = Object.keys(shardObj);\n    if (shardKeys.length !== 1) return { prettyName: null, card: null };\n    const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;\n    const modelMeta = shardData?.model_card ?? shardData?.modelCard;\n    if (!modelMeta || typeof modelMeta !== \"object\")\n      return { prettyName: null, card: null };\n    const meta = modelMeta as Record<string, unknown>;\n\n    const prettyName = (meta.prettyName as string) ?? null;\n\n    const card: ModelCardInfo = {\n      family: (meta.family as string) ?? \"\",\n      quantization: (meta.quantization as string) ?? \"\",\n      baseModel:\n        (meta.base_model as string) ?? (meta.baseModel as string) ?? \"\",\n      capabilities: Array.isArray(meta.capabilities)\n        ? (meta.capabilities as string[])\n        : [],\n      storageSize: getBytes(meta.storage_size ?? meta.storageSize),\n      nLayers: (meta.n_layers as number) ?? (meta.nLayers as number) ?? 0,\n      supportsTensor:\n        (meta.supports_tensor as boolean) ??\n        (meta.supportsTensor as boolean) ??\n        false,\n    };\n\n    return { prettyName, card };\n  }\n\n  let modelRows = $state<ModelRow[]>([]);\n  let nodeColumns = $state<NodeColumn[]>([]);\n  let infoRow = $state<ModelRow | null>(null);\n\n  $effect(() => {\n    try {\n      if (!downloadsData || Object.keys(downloadsData).length === 0) {\n        modelRows = [];\n        nodeColumns = [];\n        return;\n      }\n\n      const allNodeIds = Object.keys(downloadsData);\n      const columns: NodeColumn[] = allNodeIds.map((nodeId) => {\n        const diskInfo = nodeDiskData?.[nodeId];\n        return {\n          nodeId,\n          label: getNodeLabel(nodeId),\n          diskAvailable: diskInfo?.available?.inBytes,\n          diskTotal: diskInfo?.total?.inBytes,\n        };\n      });\n\n      const rowMap = new Map<string, ModelRow>();\n\n      for (const [nodeId, nodeDownloads] of Object.entries(downloadsData)) {\n        const entries = Array.isArray(nodeDownloads)\n          ? nodeDownloads\n          : nodeDownloads && typeof nodeDownloads === \"object\"\n            ? Object.values(nodeDownloads as Record<string, unknown>)\n            : [];\n\n        for (const entry of entries) {\n          const tagged = getDownloadTag(entry);\n          if (!tagged) continue;\n          const [tag, payload] = tagged;\n\n          const modelId =\n            extractModelIdFromDownload(payload) ?? 
\"unknown-model\";\n          const { prettyName, card } = extractModelCard(payload);\n\n          if (!rowMap.has(modelId)) {\n            rowMap.set(modelId, {\n              modelId,\n              prettyName,\n              cells: {},\n              shardMetadata: extractShardMetadata(payload),\n              modelCard: card,\n            });\n          }\n          const row = rowMap.get(modelId)!;\n          if (prettyName && !row.prettyName) row.prettyName = prettyName;\n          if (!row.shardMetadata)\n            row.shardMetadata = extractShardMetadata(payload);\n          if (!row.modelCard && card) row.modelCard = card;\n\n          const modelDirectory =\n            ((payload.model_directory ?? payload.modelDirectory) as string) ||\n            undefined;\n          let cell: CellStatus;\n          if (tag === \"DownloadCompleted\") {\n            const totalBytes = getBytes(payload.total);\n            cell = { kind: \"completed\", totalBytes, modelDirectory };\n          } else if (tag === \"DownloadOngoing\") {\n            const rawProgress =\n              payload.download_progress ?? payload.downloadProgress ?? {};\n            const prog = rawProgress as Record<string, unknown>;\n            const totalBytes = getBytes(prog.total ?? payload.total);\n            const downloadedBytes = getBytes(prog.downloaded);\n            const speed = (prog.speed as number) ?? 0;\n            const etaMs =\n              (prog.eta_ms as number) ?? (prog.etaMs as number) ?? 0;\n            const percentage =\n              totalBytes > 0 ? (downloadedBytes / totalBytes) * 100 : 0;\n            cell = {\n              kind: \"downloading\",\n              percentage: clampPercent(percentage),\n              downloadedBytes,\n              totalBytes,\n              speed,\n              etaMs,\n              modelDirectory,\n            };\n          } else if (tag === \"DownloadFailed\") {\n            cell = { kind: \"failed\", modelDirectory };\n          } else {\n            const downloaded = getBytes(\n              payload.downloaded ??\n                payload.downloaded_bytes ??\n                payload.downloadedBytes,\n            );\n            const total = getBytes(\n              payload.total ?? payload.total_bytes ?? 
payload.totalBytes,\n            );\n            cell = {\n              kind: \"pending\",\n              downloaded,\n              total,\n              modelDirectory,\n            };\n          }\n\n          const existing = row.cells[nodeId];\n          if (!existing || shouldUpgradeCell(existing, cell)) {\n            row.cells[nodeId] = cell;\n          }\n        }\n      }\n\n      function rowSortKey(row: ModelRow): number {\n        // in progress (4) -> completed (3) -> paused (2) -> not started (1) -> not present (0)\n        let best = 0;\n        for (const cell of Object.values(row.cells)) {\n          let score = 0;\n          if (cell.kind === \"downloading\") score = 4;\n          else if (cell.kind === \"completed\") score = 3;\n          else if (cell.kind === \"pending\" && cell.downloaded > 0)\n            score = 2; // paused\n          else if (cell.kind === \"pending\" || cell.kind === \"failed\") score = 1; // not started\n          if (score > best) best = score;\n        }\n        return best;\n      }\n\n      function totalCompletedBytes(row: ModelRow): number {\n        let total = 0;\n        for (const cell of Object.values(row.cells)) {\n          if (cell.kind === \"completed\") total += cell.totalBytes;\n        }\n        return total;\n      }\n\n      const rows = Array.from(rowMap.values()).sort((a, b) => {\n        const aPriority = rowSortKey(a);\n        const bPriority = rowSortKey(b);\n        if (aPriority !== bPriority) return bPriority - aPriority;\n        // Within completed or paused, sort by biggest size first\n        if (aPriority === 3 && bPriority === 3) {\n          const sizeDiff = totalCompletedBytes(b) - totalCompletedBytes(a);\n          if (sizeDiff !== 0) return sizeDiff;\n        }\n        if (aPriority === 2 && bPriority === 2) {\n          const aSize = Math.max(\n            ...Object.values(a.cells).map((c) =>\n              c.kind === \"pending\" ? c.total : 0,\n            ),\n          );\n          const bSize = Math.max(\n            ...Object.values(b.cells).map((c) =>\n              c.kind === \"pending\" ? 
c.total : 0,\n            ),\n          );\n          if (aSize !== bSize) return bSize - aSize;\n        }\n        return a.modelId.localeCompare(b.modelId);\n      });\n\n      modelRows = rows;\n      nodeColumns = columns;\n    } catch (err) {\n      console.error(\"Parse downloads error\", err);\n      modelRows = [];\n      nodeColumns = [];\n    }\n  });\n\n  const hasDownloads = $derived(modelRows.length > 0);\n  const lastUpdateTs = $derived(lastUpdateStore());\n  const downloadKeys = $derived(Object.keys(downloadsData || {}));\n\n  onMount(() => {\n    refreshState();\n  });\n</script>\n\n<div class=\"min-h-screen bg-exo-dark-gray text-white\">\n  <HeaderNav showHome={true} />\n  <div class=\"max-w-7xl mx-auto px-4 lg:px-8 py-6 space-y-6\">\n    <div class=\"flex items-center justify-between gap-4 flex-wrap\">\n      <div>\n        <h1\n          class=\"text-2xl font-mono tracking-[0.2em] uppercase text-exo-yellow\"\n        >\n          Downloads\n        </h1>\n        <p class=\"text-sm text-exo-light-gray\">\n          Overview of models on each node\n        </p>\n      </div>\n      <div class=\"flex items-center gap-3\">\n        <button\n          type=\"button\"\n          class=\"text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded\"\n          onclick={() => refreshState()}\n          title=\"Force refresh from /state\"\n        >\n          Refresh\n        </button>\n        <div class=\"text-[11px] font-mono text-exo-light-gray\">\n          Last update: {lastUpdateTs\n            ? new Date(lastUpdateTs).toLocaleTimeString()\n            : \"n/a\"}\n        </div>\n      </div>\n    </div>\n\n    {#if !hasDownloads}\n      <div\n        class=\"rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray space-y-2\"\n      >\n        <div class=\"text-sm\">\n          No downloads found. Start a model download to see progress here.\n        </div>\n        <div class=\"text-[11px] text-exo-light-gray/70\">\n          Download keys detected: {downloadKeys.length === 0\n            ? 
\"none\"\n            : downloadKeys.join(\", \")}\n        </div>\n      </div>\n    {:else}\n      <div\n        class=\"rounded border border-exo-medium-gray/30 bg-exo-black/30 overflow-x-auto\"\n      >\n        <table class=\"w-full text-left font-mono text-xs\">\n          <thead>\n            <tr class=\"border-b border-exo-medium-gray/30\">\n              <th\n                class=\"sticky left-0 z-10 bg-exo-black px-4 py-3 text-[11px] uppercase tracking-wider text-exo-yellow font-medium whitespace-nowrap border-r border-exo-medium-gray/20\"\n              >\n                Model\n              </th>\n              {#each nodeColumns as col}\n                <th\n                  class=\"px-4 py-3 text-[11px] uppercase tracking-wider text-exo-light-gray font-medium text-center whitespace-nowrap min-w-[120px]\"\n                >\n                  <div>{col.label}</div>\n                  {#if col.diskAvailable != null}\n                    <div\n                      class=\"text-[9px] text-white/70 normal-case tracking-normal mt-0.5\"\n                    >\n                      {formatBytes(col.diskAvailable)} free\n                    </div>\n                  {/if}\n                </th>\n              {/each}\n            </tr>\n          </thead>\n          <tbody>\n            {#each modelRows as row}\n              <tr\n                class=\"group border-b border-exo-medium-gray/20 hover:bg-exo-medium-gray/10 transition-colors\"\n              >\n                <td\n                  class=\"sticky left-0 z-10 bg-exo-dark-gray group-hover:bg-[oklch(0.18_0_0)] transition-colors px-4 py-3 whitespace-nowrap border-r border-exo-medium-gray/20\"\n                >\n                  <div class=\"flex items-center gap-2\">\n                    <div class=\"min-w-0\">\n                      <div class=\"text-white text-xs\" title={row.modelId}>\n                        {row.prettyName ?? row.modelId}\n                      </div>\n                      {#if row.prettyName}\n                        <div\n                          class=\"text-[10px] text-white/60\"\n                          title={row.modelId}\n                        >\n                          {row.modelId}\n                        </div>\n                      {/if}\n                    </div>\n                    <button\n                      type=\"button\"\n                      class=\"p-1 rounded hover:bg-white/10 transition-colors flex-shrink-0 opacity-60 group-hover:opacity-100\"\n                      onclick={() => (infoRow = row)}\n                      title=\"View model details\"\n                    >\n                      <svg\n                        class=\"w-4 h-4 text-white/60 hover:text-white/80\"\n                        viewBox=\"0 0 24 24\"\n                        fill=\"currentColor\"\n                      >\n                        <path\n                          d=\"M12 2C6.48 2 2 6.48 2 12s4.48 10 10 10 10-4.48 10-10S17.52 2 12 2zm1 15h-2v-6h2v6zm0-8h-2V7h2v2z\"\n                        />\n                      </svg>\n                    </button>\n                  </div>\n                </td>\n\n                {#each nodeColumns as col}\n                  {@const cell = row.cells[col.nodeId] ?? 
{\n                    kind: \"not_present\" as const,\n                  }}\n                  <td class=\"px-4 py-3 text-center align-middle\">\n                    {#if cell.kind === \"completed\"}\n                      <div\n                        class=\"flex flex-col items-center gap-1\"\n                        title=\"Completed ({formatBytes(cell.totalBytes)})\"\n                      >\n                        <svg\n                          class=\"w-7 h-7 text-green-400\"\n                          viewBox=\"0 0 20 20\"\n                          fill=\"currentColor\"\n                        >\n                          <path\n                            fill-rule=\"evenodd\"\n                            d=\"M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z\"\n                            clip-rule=\"evenodd\"\n                          ></path>\n                        </svg>\n                        <span class=\"text-xs text-white/70\"\n                          >{formatBytes(cell.totalBytes)}</span\n                        >\n                        <button\n                          type=\"button\"\n                          class=\"text-white/50 hover:text-red-400 transition-colors mt-0.5 cursor-pointer\"\n                          onclick={() =>\n                            deleteDownload(col.nodeId, row.modelId)}\n                          title=\"Delete from this node\"\n                        >\n                          <svg\n                            class=\"w-5 h-5\"\n                            viewBox=\"0 0 20 20\"\n                            fill=\"none\"\n                            stroke=\"currentColor\"\n                            stroke-width=\"2\"\n                          >\n                            <path\n                              d=\"M4 6h12M8 6V4h4v2m1 0v10a1 1 0 01-1 1H8a1 1 0 01-1-1V6h6\"\n                              stroke-linecap=\"round\"\n                              stroke-linejoin=\"round\"\n                            ></path>\n                          </svg>\n                        </button>\n                      </div>\n                    {:else if cell.kind === \"downloading\"}\n                      <div\n                        class=\"flex flex-col items-center gap-1\"\n                        title=\"{formatBytes(\n                          cell.downloadedBytes,\n                        )} / {formatBytes(cell.totalBytes)} - {formatSpeed(\n                          cell.speed,\n                        )} - ETA {formatEta(cell.etaMs)}\"\n                      >\n                        <span class=\"text-exo-yellow text-sm font-medium\"\n                          >{clampPercent(cell.percentage).toFixed(1)}%</span\n                        >\n                        <div\n                          class=\"w-16 h-2 bg-exo-black/60 rounded-sm overflow-hidden\"\n                        >\n                          <div\n                            class=\"h-full bg-gradient-to-r from-exo-yellow to-exo-yellow/70 transition-all duration-300\"\n                            style=\"width: {clampPercent(\n                              cell.percentage,\n                            ).toFixed(1)}%\"\n                          ></div>\n                        </div>\n                        <span class=\"text-[10px] text-white/70\"\n                          >{formatSpeed(cell.speed)}</span\n                        >\n                      </div>\n                    
{:else if cell.kind === \"pending\"}\n                      <div\n                        class=\"flex flex-col items-center gap-1\"\n                        title={cell.downloaded > 0\n                          ? `${formatBytes(cell.downloaded)} / ${formatBytes(cell.total)} downloaded (paused)`\n                          : \"Download pending\"}\n                      >\n                        {#if cell.downloaded > 0 && cell.total > 0}\n                          <span class=\"text-white/70 text-xs\"\n                            >{formatBytes(cell.downloaded)} / {formatBytes(\n                              cell.total,\n                            )}</span\n                          >\n                          <div\n                            class=\"w-full h-1.5 bg-white/10 rounded-full overflow-hidden\"\n                          >\n                            <div\n                              class=\"h-full bg-exo-light-gray/40 rounded-full\"\n                              style=\"width: {(\n                                (cell.downloaded / cell.total) *\n                                100\n                              ).toFixed(1)}%\"\n                            ></div>\n                          </div>\n                          {#if row.shardMetadata}\n                            <button\n                              type=\"button\"\n                              class=\"text-white/50 hover:text-exo-yellow transition-colors cursor-pointer\"\n                              onclick={() =>\n                                startDownload(col.nodeId, row.shardMetadata!)}\n                              title=\"Resume download on this node\"\n                            >\n                              <svg\n                                class=\"w-5 h-5\"\n                                viewBox=\"0 0 20 20\"\n                                fill=\"none\"\n                                stroke=\"currentColor\"\n                                stroke-width=\"2\"\n                              >\n                                <path\n                                  d=\"M10 3v10m0 0l-3-3m3 3l3-3M3 17h14\"\n                                  stroke-linecap=\"round\"\n                                  stroke-linejoin=\"round\"\n                                ></path>\n                              </svg>\n                            </button>\n                          {:else}\n                            <span class=\"text-white/50 text-[10px]\">paused</span\n                            >\n                          {/if}\n                        {:else if row.shardMetadata}\n                          <button\n                            type=\"button\"\n                            class=\"text-white/50 hover:text-exo-yellow transition-colors cursor-pointer\"\n                            onclick={() =>\n                              startDownload(col.nodeId, row.shardMetadata!)}\n                            title=\"Start download on this node\"\n                          >\n                            <svg\n                              class=\"w-6 h-6\"\n                              viewBox=\"0 0 20 20\"\n                              fill=\"none\"\n                              stroke=\"currentColor\"\n                              stroke-width=\"2\"\n                            >\n                              <path\n                                d=\"M10 3v10m0 0l-3-3m3 3l3-3M3 17h14\"\n                                stroke-linecap=\"round\"\n                                
stroke-linejoin=\"round\"\n                              ></path>\n                            </svg>\n                          </button>\n                        {:else}\n                          <span class=\"text-white/40 text-sm\">...</span>\n                        {/if}\n                      </div>\n                    {:else if cell.kind === \"failed\"}\n                      <div\n                        class=\"flex flex-col items-center gap-1\"\n                        title=\"Download failed\"\n                      >\n                        <svg\n                          class=\"w-7 h-7 text-red-400\"\n                          viewBox=\"0 0 20 20\"\n                          fill=\"currentColor\"\n                        >\n                          <path\n                            fill-rule=\"evenodd\"\n                            d=\"M4.293 4.293a1 1 0 011.414 0L10 8.586l4.293-4.293a1 1 0 111.414 1.414L11.414 10l4.293 4.293a1 1 0 01-1.414 1.414L10 11.414l-4.293 4.293a1 1 0 01-1.414-1.414L8.586 10 4.293 5.707a1 1 0 010-1.414z\"\n                            clip-rule=\"evenodd\"\n                          ></path>\n                        </svg>\n                        {#if row.shardMetadata}\n                          <button\n                            type=\"button\"\n                            class=\"text-white/50 hover:text-exo-yellow transition-colors cursor-pointer\"\n                            onclick={() =>\n                              startDownload(col.nodeId, row.shardMetadata!)}\n                            title=\"Retry download on this node\"\n                          >\n                            <svg\n                              class=\"w-5 h-5\"\n                              viewBox=\"0 0 20 20\"\n                              fill=\"none\"\n                              stroke=\"currentColor\"\n                              stroke-width=\"2\"\n                            >\n                              <path\n                                d=\"M10 3v10m0 0l-3-3m3 3l3-3M3 17h14\"\n                                stroke-linecap=\"round\"\n                                stroke-linejoin=\"round\"\n                              ></path>\n                            </svg>\n                          </button>\n                        {/if}\n                      </div>\n                    {:else}\n                      <div\n                        class=\"flex flex-col items-center\"\n                        title=\"Not on this node\"\n                      >\n                        <span class=\"text-exo-medium-gray text-lg leading-none\"\n                          >--</span\n                        >\n                        {#if row.shardMetadata}\n                          <button\n                            type=\"button\"\n                            class=\"text-white/50 hover:text-exo-yellow transition-colors mt-0.5 opacity-0 group-hover:opacity-100 cursor-pointer\"\n                            onclick={() =>\n                              startDownload(col.nodeId, row.shardMetadata!)}\n                            title=\"Download to this node\"\n                          >\n                            <svg\n                              class=\"w-5 h-5\"\n                              viewBox=\"0 0 20 20\"\n                              fill=\"none\"\n                              stroke=\"currentColor\"\n                              stroke-width=\"2\"\n                            >\n                              <path\n          
                      d=\"M10 3v10m0 0l-3-3m3 3l3-3M3 17h14\"\n                                stroke-linecap=\"round\"\n                                stroke-linejoin=\"round\"\n                              ></path>\n                            </svg>\n                          </button>\n                        {/if}\n                      </div>\n                    {/if}\n                  </td>\n                {/each}\n              </tr>\n            {/each}\n          </tbody>\n        </table>\n      </div>\n    {/if}\n  </div>\n</div>\n\n<!-- Info modal -->\n{#if infoRow}\n  <div\n    class=\"fixed inset-0 z-[60] bg-black/60\"\n    transition:fade={{ duration: 150 }}\n    onclick={() => (infoRow = null)}\n    role=\"presentation\"\n  ></div>\n  <div\n    class=\"fixed z-[60] top-1/2 left-1/2 -translate-x-1/2 -translate-y-1/2 w-[min(80vw,400px)] bg-exo-dark-gray border border-exo-yellow/10 rounded-lg shadow-2xl p-4\"\n    transition:fly={{ y: 10, duration: 200, easing: cubicOut }}\n    role=\"dialog\"\n    aria-modal=\"true\"\n  >\n    <div class=\"flex items-start justify-between mb-3\">\n      <h3 class=\"font-mono text-lg text-white\">\n        {infoRow.prettyName ?? infoRow.modelId}\n      </h3>\n      <button\n        type=\"button\"\n        class=\"p-1 rounded hover:bg-white/10 transition-colors text-white/50\"\n        onclick={() => (infoRow = null)}\n        title=\"Close model details\"\n        aria-label=\"Close info dialog\"\n      >\n        <svg class=\"w-4 h-4\" viewBox=\"0 0 24 24\" fill=\"currentColor\">\n          <path\n            d=\"M19 6.41L17.59 5 12 10.59 6.41 5 5 6.41 10.59 12 5 17.59 6.41 19 12 13.41 17.59 19 19 17.59 13.41 12 19 6.41z\"\n          />\n        </svg>\n      </button>\n    </div>\n    <div class=\"space-y-2 text-xs font-mono\">\n      <div class=\"flex items-center gap-2\">\n        <span class=\"text-white/40\">Model ID:</span>\n        <span class=\"text-white/70\">{infoRow.modelId}</span>\n      </div>\n      {#if infoRow.modelCard}\n        {#if infoRow.modelCard.family}\n          <div class=\"flex items-center gap-2\">\n            <span class=\"text-white/40\">Family:</span>\n            <span class=\"text-white/70\">{infoRow.modelCard.family}</span>\n          </div>\n        {/if}\n        {#if infoRow.modelCard.baseModel}\n          <div class=\"flex items-center gap-2\">\n            <span class=\"text-white/40\">Base model:</span>\n            <span class=\"text-white/70\">{infoRow.modelCard.baseModel}</span>\n          </div>\n        {/if}\n        {#if infoRow.modelCard.quantization}\n          <div class=\"flex items-center gap-2\">\n            <span class=\"text-white/40\">Quantization:</span>\n            <span class=\"text-white/70\">{infoRow.modelCard.quantization}</span>\n          </div>\n        {/if}\n        {#if infoRow.modelCard.storageSize > 0}\n          <div class=\"flex items-center gap-2\">\n            <span class=\"text-white/40\">Size:</span>\n            <span class=\"text-white/70\"\n              >{formatBytes(infoRow.modelCard.storageSize)}</span\n            >\n          </div>\n        {/if}\n        {#if infoRow.modelCard.nLayers > 0}\n          <div class=\"flex items-center gap-2\">\n            <span class=\"text-white/40\">Layers:</span>\n            <span class=\"text-white/70\">{infoRow.modelCard.nLayers}</span>\n          </div>\n        {/if}\n        {#if infoRow.modelCard.capabilities.length > 0}\n          <div class=\"flex items-center gap-2\">\n            <span 
class=\"text-white/40\">Capabilities:</span>\n            <span class=\"text-white/70\"\n              >{infoRow.modelCard.capabilities.join(\", \")}</span\n            >\n          </div>\n        {/if}\n        <div class=\"flex items-center gap-2\">\n          <span class=\"text-white/40\">Tensor parallelism:</span>\n          <span class=\"text-white/70\"\n            >{infoRow.modelCard.supportsTensor ? \"Yes\" : \"No\"}</span\n          >\n        </div>\n      {/if}\n\n      <!-- Per-node download status -->\n      {#if nodeColumns.filter((col) => (infoRow?.cells[col.nodeId]?.kind ?? \"not_present\") !== \"not_present\").length > 0}\n        <div class=\"mt-3 pt-3 border-t border-exo-yellow/10\">\n          <div class=\"flex items-center gap-2 mb-1\">\n            <svg\n              class=\"w-3.5 h-3.5\"\n              viewBox=\"0 0 24 24\"\n              fill=\"none\"\n              stroke=\"currentColor\"\n              stroke-width=\"2\"\n              stroke-linecap=\"round\"\n              stroke-linejoin=\"round\"\n            >\n              <path\n                class=\"text-white/40\"\n                d=\"M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z\"\n              />\n              <path class=\"text-green-400\" d=\"m9 13 2 2 4-4\" />\n            </svg>\n            <span class=\"text-white/40\">On nodes:</span>\n          </div>\n          <div class=\"flex flex-col gap-1.5 mt-1\">\n            {#each nodeColumns as col}\n              {@const cellStatus = infoRow?.cells[col.nodeId]}\n              {#if cellStatus && cellStatus.kind !== \"not_present\"}\n                <div class=\"flex flex-col gap-0.5\">\n                  <span\n                    class=\"inline-block w-fit px-1.5 py-0.5 rounded text-[10px] {cellStatus.kind ===\n                    'completed'\n                      ? 'bg-green-500/10 text-green-400/80 border border-green-500/20'\n                      : cellStatus.kind === 'downloading'\n                        ? 'bg-exo-yellow/10 text-exo-yellow/80 border border-exo-yellow/20'\n                        : cellStatus.kind === 'failed'\n                          ? 'bg-red-500/10 text-red-400/80 border border-red-500/20'\n                          : 'bg-white/5 text-white/50 border border-white/10'}\"\n                  >\n                    {col.label}\n                    {#if cellStatus.kind === \"downloading\" && \"percentage\" in cellStatus}\n                      ({clampPercent(cellStatus.percentage).toFixed(0)}%)\n                    {/if}\n                  </span>\n                  {#if \"modelDirectory\" in cellStatus && cellStatus.modelDirectory}\n                    <span\n                      class=\"text-[9px] text-white/30 break-all pl-1\"\n                      title={cellStatus.modelDirectory}\n                    >\n                      {cellStatus.modelDirectory}\n                    </span>\n                  {/if}\n                </div>\n              {/if}\n            {/each}\n          </div>\n        </div>\n      {/if}\n    </div>\n  </div>\n{/if}\n\n<style>\n  table {\n    min-width: max-content;\n  }\n</style>\n"
  },
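The downloads table above leans on three helpers imported from `$lib/utils/downloads` that are not included in this dump: `getDownloadTag`, `extractModelIdFromDownload`, and `extractShardMetadata`. Judging from the call sites (`const [tag, payload] = tagged` and the comparisons against `DownloadCompleted` / `DownloadOngoing` / `DownloadFailed`), each download entry appears to arrive as a single-key tagged object. A minimal sketch of the tag helper under that assumption (illustrative reconstruction, not the repo's actual code):

```typescript
type DownloadPayload = Record<string, unknown>;

// Hypothetical reconstruction of the helper from $lib/utils/downloads,
// assuming entries arrive as single-key tagged objects such as
// { "DownloadOngoing": { download_progress: {...}, total: {...} } }.
export function getDownloadTag(
  entry: unknown,
): [tag: string, payload: DownloadPayload] | null {
  if (!entry || typeof entry !== "object") return null;
  const record = entry as Record<string, unknown>;
  const keys = Object.keys(record);
  if (keys.length !== 1) return null;
  const payload = record[keys[0]];
  if (!payload || typeof payload !== "object") return null;
  return [keys[0], payload as DownloadPayload];
}
```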
  {
    "path": "dashboard/src/routes/traces/+page.svelte",
    "content": "<script lang=\"ts\">\n  import { onMount } from \"svelte\";\n  import {\n    listTraces,\n    getTraceRawUrl,\n    deleteTraces,\n    type TraceListItem,\n  } from \"$lib/stores/app.svelte\";\n  import HeaderNav from \"$lib/components/HeaderNav.svelte\";\n\n  let traces = $state<TraceListItem[]>([]);\n  let loading = $state(true);\n  let error = $state<string | null>(null);\n  let selectedIds = $state<Set<string>>(new Set());\n  let deleting = $state(false);\n\n  let allSelected = $derived(\n    traces.length > 0 && selectedIds.size === traces.length,\n  );\n\n  function toggleSelect(taskId: string) {\n    const next = new Set(selectedIds);\n    if (next.has(taskId)) {\n      next.delete(taskId);\n    } else {\n      next.add(taskId);\n    }\n    selectedIds = next;\n  }\n\n  function toggleSelectAll() {\n    if (allSelected) {\n      selectedIds = new Set();\n    } else {\n      selectedIds = new Set(traces.map((t) => t.taskId));\n    }\n  }\n\n  async function handleDelete() {\n    if (selectedIds.size === 0) return;\n    const count = selectedIds.size;\n    if (\n      !confirm(\n        `Delete ${count} trace${count === 1 ? \"\" : \"s\"}? This cannot be undone.`,\n      )\n    )\n      return;\n    deleting = true;\n    try {\n      await deleteTraces([...selectedIds]);\n      selectedIds = new Set();\n      await refresh();\n    } catch (e) {\n      error = e instanceof Error ? e.message : \"Failed to delete traces\";\n    } finally {\n      deleting = false;\n    }\n  }\n\n  function formatBytes(bytes: number): string {\n    if (!bytes || bytes <= 0) return \"0B\";\n    const units = [\"B\", \"KB\", \"MB\", \"GB\"];\n    const i = Math.min(\n      Math.floor(Math.log(bytes) / Math.log(1024)),\n      units.length - 1,\n    );\n    const val = bytes / Math.pow(1024, i);\n    return `${val.toFixed(val >= 10 ? 0 : 1)}${units[i]}`;\n  }\n\n  function formatDate(isoString: string): string {\n    const date = new Date(isoString);\n    return date.toLocaleString();\n  }\n\n  async function downloadTrace(taskId: string) {\n    const response = await fetch(getTraceRawUrl(taskId));\n    const blob = await response.blob();\n    const url = URL.createObjectURL(blob);\n    const a = document.createElement(\"a\");\n    a.href = url;\n    a.download = `trace_${taskId}.json`;\n    a.click();\n    URL.revokeObjectURL(url);\n  }\n\n  async function openInPerfetto(taskId: string) {\n    // Fetch trace data from our local API\n    const response = await fetch(getTraceRawUrl(taskId));\n    const traceData = await response.arrayBuffer();\n\n    // Open Perfetto UI\n    const perfettoWindow = window.open(\"https://ui.perfetto.dev\");\n    if (!perfettoWindow) {\n      alert(\"Failed to open Perfetto. 
Please allow popups.\");\n      return;\n    }\n\n    // Wait for Perfetto to be ready, then send trace via postMessage\n    const onMessage = (e: MessageEvent) => {\n      if (e.data === \"PONG\") {\n        window.removeEventListener(\"message\", onMessage);\n        perfettoWindow.postMessage(\n          {\n            perfetto: {\n              buffer: traceData,\n              title: `Trace ${taskId}`,\n            },\n          },\n          \"https://ui.perfetto.dev\",\n        );\n      }\n    };\n    window.addEventListener(\"message\", onMessage);\n\n    // Ping Perfetto until it responds\n    const pingInterval = setInterval(() => {\n      perfettoWindow.postMessage(\"PING\", \"https://ui.perfetto.dev\");\n    }, 50);\n\n    // Clean up after 10 seconds\n    setTimeout(() => {\n      clearInterval(pingInterval);\n      window.removeEventListener(\"message\", onMessage);\n    }, 10000);\n  }\n\n  async function refresh() {\n    loading = true;\n    error = null;\n    try {\n      const response = await listTraces();\n      traces = response.traces;\n    } catch (e) {\n      error = e instanceof Error ? e.message : \"Failed to load traces\";\n    } finally {\n      loading = false;\n    }\n  }\n\n  onMount(() => {\n    refresh();\n  });\n</script>\n\n<div class=\"min-h-screen bg-exo-dark-gray text-white\">\n  <HeaderNav showHome={true} />\n  <div class=\"max-w-7xl mx-auto px-4 lg:px-8 py-6 space-y-6\">\n    <div class=\"flex items-center justify-between gap-4 flex-wrap\">\n      <div>\n        <h1\n          class=\"text-2xl font-mono tracking-[0.2em] uppercase text-exo-yellow\"\n        >\n          Traces\n        </h1>\n      </div>\n      <div class=\"flex items-center gap-3\">\n        {#if selectedIds.size > 0}\n          <button\n            type=\"button\"\n            class=\"text-xs font-mono text-red-400 hover:text-red-300 transition-colors uppercase border border-red-500/40 px-2 py-1 rounded\"\n            onclick={handleDelete}\n            disabled={deleting}\n          >\n            {deleting ? \"Deleting...\" : `Delete (${selectedIds.size})`}\n          </button>\n        {/if}\n        <button\n          type=\"button\"\n          class=\"text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded\"\n          onclick={refresh}\n          disabled={loading}\n        >\n          Refresh\n        </button>\n      </div>\n    </div>\n\n    {#if loading}\n      <div\n        class=\"rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray\"\n      >\n        <div class=\"text-sm\">Loading traces...</div>\n      </div>\n    {:else if error}\n      <div\n        class=\"rounded border border-red-500/30 bg-red-500/10 p-6 text-center text-red-400\"\n      >\n        <div class=\"text-sm\">{error}</div>\n      </div>\n    {:else if traces.length === 0}\n      <div\n        class=\"rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray space-y-2\"\n      >\n        <div class=\"text-sm\">No traces found.</div>\n        <div class=\"text-xs text-exo-light-gray/70\">\n          Run exo with EXO_TRACING_ENABLED=1 to collect traces.\n        </div>\n      </div>\n    {:else}\n      <div class=\"space-y-3\">\n        <div class=\"flex items-center gap-2 px-1\">\n          <button\n            type=\"button\"\n            class=\"text-xs font-mono uppercase transition-colors {allSelected\n              ? 
'text-exo-yellow'\n              : 'text-exo-light-gray hover:text-exo-yellow'}\"\n            onclick={toggleSelectAll}\n          >\n            {allSelected ? \"Deselect all\" : \"Select all\"}\n          </button>\n        </div>\n        {#each traces as trace}\n          {@const isSelected = selectedIds.has(trace.taskId)}\n          <!-- svelte-ignore a11y_no_static_element_interactions -->\n          <div\n            role=\"button\"\n            tabindex=\"0\"\n            class=\"w-full text-left rounded border-l-2 border-r border-t border-b transition-all p-4 flex items-center justify-between gap-4 cursor-pointer {isSelected\n              ? 'bg-exo-yellow/10 border-l-exo-yellow border-r-exo-medium-gray/30 border-t-exo-medium-gray/30 border-b-exo-medium-gray/30'\n              : 'bg-exo-black/30 border-l-transparent border-r-exo-medium-gray/30 border-t-exo-medium-gray/30 border-b-exo-medium-gray/30 hover:bg-white/[0.03]'}\"\n            onclick={() => toggleSelect(trace.taskId)}\n            onkeydown={(e) => {\n              if (e.key === \"Enter\" || e.key === \" \") {\n                e.preventDefault();\n                toggleSelect(trace.taskId);\n              }\n            }}\n          >\n            <div class=\"min-w-0 flex-1\">\n              <a\n                href=\"#/traces/{trace.taskId}\"\n                class=\"text-sm font-mono transition-colors truncate block {isSelected\n                  ? 'text-exo-yellow'\n                  : 'text-white hover:text-exo-yellow'}\"\n                onclick={(e) => e.stopPropagation()}\n              >\n                {trace.taskId}\n              </a>\n              <div class=\"text-xs text-exo-light-gray font-mono mt-1\">\n                {formatDate(trace.createdAt)} &bull; {formatBytes(\n                  trace.fileSize,\n                )}\n              </div>\n            </div>\n            <!-- svelte-ignore a11y_click_events_have_key_events -->\n            <div\n              class=\"flex items-center gap-2 shrink-0\"\n              onclick={(e) => e.stopPropagation()}\n            >\n              <a\n                href=\"#/traces/{trace.taskId}\"\n                class=\"text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded\"\n              >\n                View Stats\n              </a>\n              <button\n                type=\"button\"\n                class=\"text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded\"\n                onclick={() => downloadTrace(trace.taskId)}\n              >\n                Download\n              </button>\n              <button\n                type=\"button\"\n                class=\"text-xs font-mono text-exo-dark-gray bg-exo-yellow hover:bg-exo-yellow/90 transition-colors uppercase px-2 py-1 rounded font-semibold\"\n                onclick={() => openInPerfetto(trace.taskId)}\n              >\n                View Trace\n              </button>\n            </div>\n          </div>\n        {/each}\n      </div>\n    {/if}\n  </div>\n</div>\n"
  },
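`openInPerfetto` above is duplicated verbatim in the per-trace route that follows. The PING/PONG handshake it performs is Perfetto UI's documented deep-link protocol: open `ui.perfetto.dev`, ping it until the page answers `PONG`, then post the trace buffer via `postMessage`. A shared helper would remove the duplication; a sketch (the `openTraceInPerfetto` name is invented here, and unlike the inline versions it also stops pinging as soon as the handshake succeeds):

```typescript
const PERFETTO_ORIGIN = "https://ui.perfetto.dev";

// Sketch of a shared helper for the handshake duplicated across both
// trace routes; behavior mirrors the inline versions in this dump.
export function openTraceInPerfetto(buffer: ArrayBuffer, title: string): void {
  const perfettoWindow = window.open(PERFETTO_ORIGIN);
  if (!perfettoWindow) {
    alert("Failed to open Perfetto. Please allow popups.");
    return;
  }
  let pingInterval: ReturnType<typeof setInterval> | undefined;
  const onMessage = (e: MessageEvent) => {
    if (e.data !== "PONG") return;
    // Perfetto answers PONG once its message handler is installed.
    clearInterval(pingInterval);
    window.removeEventListener("message", onMessage);
    perfettoWindow.postMessage({ perfetto: { buffer, title } }, PERFETTO_ORIGIN);
  };
  window.addEventListener("message", onMessage);
  pingInterval = setInterval(
    () => perfettoWindow.postMessage("PING", PERFETTO_ORIGIN),
    50,
  );
  // Give up after ten seconds if the UI never responds.
  setTimeout(() => {
    clearInterval(pingInterval);
    window.removeEventListener("message", onMessage);
  }, 10_000);
}
```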
  {
    "path": "dashboard/src/routes/traces/[taskId]/+page.svelte",
    "content": "<script lang=\"ts\">\n  import { page } from \"$app/stores\";\n  import { onMount } from \"svelte\";\n  import {\n    fetchTraceStats,\n    getTraceRawUrl,\n    type TraceStatsResponse,\n    type TraceCategoryStats,\n  } from \"$lib/stores/app.svelte\";\n  import HeaderNav from \"$lib/components/HeaderNav.svelte\";\n\n  const taskId = $derived($page.params.taskId);\n\n  let stats = $state<TraceStatsResponse | null>(null);\n  let loading = $state(true);\n  let error = $state<string | null>(null);\n\n  function formatDuration(us: number): string {\n    if (us < 1000) return `${us.toFixed(0)}us`;\n    if (us < 1_000_000) return `${(us / 1000).toFixed(2)}ms`;\n    return `${(us / 1_000_000).toFixed(2)}s`;\n  }\n\n  function formatPercentage(part: number, total: number): string {\n    if (total === 0) return \"0.0%\";\n    return `${((part / total) * 100).toFixed(1)}%`;\n  }\n\n  // Parse hierarchical categories like \"sync/compute\" into phases\n  type PhaseData = {\n    name: string;\n    subcategories: { name: string; stats: TraceCategoryStats }[];\n    totalUs: number; // From outer span (e.g., \"sync\" category)\n    stepCount: number; // Count of outer span events\n  };\n\n  function parsePhases(\n    byCategory: Record<string, TraceCategoryStats>,\n  ): PhaseData[] {\n    const phases = new Map<\n      string,\n      {\n        subcats: Map<string, TraceCategoryStats>;\n        outerStats: TraceCategoryStats | null;\n      }\n    >();\n\n    for (const [category, catStats] of Object.entries(byCategory)) {\n      if (category.includes(\"/\")) {\n        const [phase, subcat] = category.split(\"/\", 2);\n        if (!phases.has(phase)) {\n          phases.set(phase, { subcats: new Map(), outerStats: null });\n        }\n        phases.get(phase)!.subcats.set(subcat, catStats);\n      } else {\n        // Outer span - this IS the phase total\n        if (!phases.has(category)) {\n          phases.set(category, { subcats: new Map(), outerStats: null });\n        }\n        phases.get(category)!.outerStats = catStats;\n      }\n    }\n\n    return Array.from(phases.entries())\n      .filter(([_, data]) => data.outerStats !== null) // Only phases with outer spans\n      .map(([name, data]) => ({\n        name,\n        subcategories: Array.from(data.subcats.entries())\n          .map(([subName, subStats]) => ({ name: subName, stats: subStats }))\n          .sort((a, b) => b.stats.totalUs - a.stats.totalUs),\n        totalUs: data.outerStats!.totalUs, // Outer span total\n        stepCount: data.outerStats!.count, // Number of steps\n      }))\n      .sort((a, b) => b.totalUs - a.totalUs);\n  }\n\n  async function downloadTrace() {\n    if (!taskId) return;\n    const response = await fetch(getTraceRawUrl(taskId));\n    const blob = await response.blob();\n    const url = URL.createObjectURL(blob);\n    const a = document.createElement(\"a\");\n    a.href = url;\n    a.download = `trace_${taskId}.json`;\n    a.click();\n    URL.revokeObjectURL(url);\n  }\n\n  async function openInPerfetto() {\n    if (!taskId) return;\n\n    // Fetch trace data from our local API\n    const response = await fetch(getTraceRawUrl(taskId));\n    const traceData = await response.arrayBuffer();\n\n    // Open Perfetto UI\n    const perfettoWindow = window.open(\"https://ui.perfetto.dev\");\n    if (!perfettoWindow) {\n      alert(\"Failed to open Perfetto. 
Please allow popups.\");\n      return;\n    }\n\n    // Wait for Perfetto to be ready, then send trace via postMessage\n    const onMessage = (e: MessageEvent) => {\n      if (e.data === \"PONG\") {\n        window.removeEventListener(\"message\", onMessage);\n        perfettoWindow.postMessage(\n          {\n            perfetto: {\n              buffer: traceData,\n              title: `Trace ${taskId}`,\n            },\n          },\n          \"https://ui.perfetto.dev\",\n        );\n      }\n    };\n    window.addEventListener(\"message\", onMessage);\n\n    // Ping Perfetto until it responds\n    const pingInterval = setInterval(() => {\n      perfettoWindow.postMessage(\"PING\", \"https://ui.perfetto.dev\");\n    }, 50);\n\n    // Clean up after 10 seconds\n    setTimeout(() => {\n      clearInterval(pingInterval);\n      window.removeEventListener(\"message\", onMessage);\n    }, 10000);\n  }\n\n  onMount(async () => {\n    if (!taskId) {\n      error = \"No task ID provided\";\n      loading = false;\n      return;\n    }\n\n    try {\n      stats = await fetchTraceStats(taskId);\n    } catch (e) {\n      error = e instanceof Error ? e.message : \"Failed to load trace\";\n    } finally {\n      loading = false;\n    }\n  });\n\n  const phases = $derived(stats ? parsePhases(stats.byCategory) : []);\n  const sortedRanks = $derived(\n    stats\n      ? Object.keys(stats.byRank)\n          .map(Number)\n          .sort((a, b) => a - b)\n      : [],\n  );\n  const nodeCount = $derived(sortedRanks.length || 1);\n</script>\n\n<div class=\"min-h-screen bg-exo-dark-gray text-white\">\n  <HeaderNav showHome={true} />\n  <div class=\"max-w-7xl mx-auto px-4 lg:px-8 py-6 space-y-6\">\n    <div class=\"flex items-center justify-between gap-4 flex-wrap\">\n      <div>\n        <h1\n          class=\"text-2xl font-mono tracking-[0.2em] uppercase text-exo-yellow\"\n        >\n          Trace\n        </h1>\n        <p class=\"text-sm text-exo-light-gray font-mono truncate max-w-lg\">\n          {taskId}\n        </p>\n      </div>\n      <div class=\"flex items-center gap-3\">\n        <a\n          href=\"#/traces\"\n          class=\"text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-3 py-1.5 rounded\"\n        >\n          All Traces\n        </a>\n        <button\n          type=\"button\"\n          class=\"text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-3 py-1.5 rounded\"\n          onclick={downloadTrace}\n          disabled={loading || !!error}\n        >\n          Download\n        </button>\n        <button\n          type=\"button\"\n          class=\"text-xs font-mono text-exo-dark-gray bg-exo-yellow hover:bg-exo-yellow/90 transition-colors uppercase px-3 py-1.5 rounded font-semibold\"\n          onclick={openInPerfetto}\n          disabled={loading || !!error}\n        >\n          View Trace\n        </button>\n      </div>\n    </div>\n\n    {#if loading}\n      <div\n        class=\"rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray\"\n      >\n        <div class=\"text-sm\">Loading trace data...</div>\n      </div>\n    {:else if error}\n      <div\n        class=\"rounded border border-red-500/30 bg-red-500/10 p-6 text-center text-red-400\"\n      >\n        <div class=\"text-sm\">{error}</div>\n      </div>\n    {:else if stats}\n      <!-- Wall Time Summary -->\n      <div\n        
class=\"rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 space-y-2\"\n      >\n        <h2\n          class=\"text-sm font-mono uppercase tracking-wider text-exo-light-gray\"\n        >\n          Summary\n        </h2>\n        <div class=\"text-3xl font-mono text-exo-yellow\">\n          {formatDuration(stats.totalWallTimeUs)}\n        </div>\n        <div class=\"text-xs text-exo-light-gray\">Total wall time</div>\n      </div>\n\n      <!-- By Phase -->\n      {#if phases.length > 0}\n        <div\n          class=\"rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 space-y-4\"\n        >\n          <h2\n            class=\"text-sm font-mono uppercase tracking-wider text-exo-light-gray\"\n          >\n            By Phase <span class=\"text-exo-light-gray/50\">(avg per node)</span>\n          </h2>\n          <div class=\"space-y-4\">\n            {#each phases as phase}\n              {@const normalizedTotal = phase.totalUs / nodeCount}\n              {@const normalizedStepCount = phase.stepCount / nodeCount}\n              <div class=\"space-y-2\">\n                <div class=\"flex items-center justify-between\">\n                  <span class=\"text-sm font-mono text-white\">{phase.name}</span>\n                  <span class=\"text-sm font-mono\">\n                    <span class=\"text-exo-yellow\"\n                      >{formatDuration(normalizedTotal)}</span\n                    >\n                    <span class=\"text-exo-light-gray ml-2\">\n                      ({normalizedStepCount} steps, {formatDuration(\n                        normalizedTotal / normalizedStepCount,\n                      )}/step)\n                    </span>\n                  </span>\n                </div>\n                {#if phase.subcategories.length > 0}\n                  <div class=\"pl-4 space-y-1.5\">\n                    {#each phase.subcategories as subcat}\n                      {@const normalizedSubcat =\n                        subcat.stats.totalUs / nodeCount}\n                      {@const pct = formatPercentage(\n                        normalizedSubcat,\n                        normalizedTotal,\n                      )}\n                      {@const perStep = normalizedSubcat / normalizedStepCount}\n                      <div\n                        class=\"flex items-center justify-between text-xs font-mono\"\n                      >\n                        <span class=\"text-exo-light-gray\">{subcat.name}</span>\n                        <span class=\"text-white\">\n                          {formatDuration(normalizedSubcat)}\n                          <span class=\"text-exo-light-gray ml-2\">({pct})</span>\n                          <span class=\"text-exo-light-gray/60 ml-2\"\n                            >{formatDuration(perStep)}/step</span\n                          >\n                        </span>\n                      </div>\n                      <!-- Progress bar -->\n                      <div\n                        class=\"relative h-1.5 bg-exo-black/60 rounded-sm overflow-hidden\"\n                      >\n                        <div\n                          class=\"absolute inset-y-0 left-0 bg-gradient-to-r from-exo-yellow to-exo-yellow/70 transition-all duration-300\"\n                          style=\"width: {pct}\"\n                        ></div>\n                      </div>\n                    {/each}\n                  </div>\n                {/if}\n              </div>\n            {/each}\n          </div>\n        </div>\n      
{/if}\n\n      <!-- By Rank -->\n      {#if sortedRanks.length > 0}\n        <div\n          class=\"rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 space-y-4\"\n        >\n          <h2\n            class=\"text-sm font-mono uppercase tracking-wider text-exo-light-gray\"\n          >\n            By Rank\n          </h2>\n          <div class=\"grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4\">\n            {#each sortedRanks as rank}\n              {@const rankStats = stats.byRank[rank]}\n              {@const rankPhases = parsePhases(rankStats.byCategory)}\n              <div\n                class=\"rounded border border-exo-medium-gray/20 bg-exo-dark-gray/60 p-3 space-y-3\"\n              >\n                <div class=\"text-sm font-mono text-exo-yellow\">\n                  Rank {rank}\n                </div>\n                <div class=\"space-y-2\">\n                  {#each rankPhases as phase}\n                    <div class=\"space-y-1\">\n                      <div class=\"flex items-center justify-between text-xs\">\n                        <span class=\"font-mono text-exo-light-gray\"\n                          >{phase.name}</span\n                        >\n                        <span class=\"font-mono text-white\">\n                          {formatDuration(phase.totalUs)}\n                          <span class=\"text-exo-light-gray/50 ml-1\">\n                            ({phase.stepCount}x)\n                          </span>\n                        </span>\n                      </div>\n                      {#if phase.subcategories.length > 0}\n                        <div class=\"pl-2 space-y-0.5\">\n                          {#each phase.subcategories as subcat}\n                            {@const pct = formatPercentage(\n                              subcat.stats.totalUs,\n                              phase.totalUs,\n                            )}\n                            {@const perStep =\n                              subcat.stats.totalUs / phase.stepCount}\n                            <div\n                              class=\"flex items-center justify-between text-[10px] font-mono\"\n                            >\n                              <span class=\"text-exo-light-gray/70\"\n                                >{subcat.name}</span\n                              >\n                              <span class=\"text-exo-light-gray\">\n                                {formatDuration(subcat.stats.totalUs)}\n                                <span class=\"text-exo-light-gray/50\"\n                                  >({pct})</span\n                                >\n                                <span class=\"text-exo-light-gray/30 ml-1\"\n                                  >{formatDuration(perStep)}/step</span\n                                >\n                              </span>\n                            </div>\n                          {/each}\n                        </div>\n                      {/if}\n                    </div>\n                  {/each}\n                </div>\n              </div>\n            {/each}\n          </div>\n        </div>\n      {/if}\n    {/if}\n  </div>\n</div>\n"
  },
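For reference, `parsePhases` in the page above groups the flat `byCategory` map by its `outer/inner` naming convention: a plain key (e.g. `sync`) is an outer span that supplies the phase total and step count, while slash-separated keys (e.g. `sync/compute`) become its subcategories, sorted by total time descending. A small worked example with made-up numbers, assuming `TraceCategoryStats` carries at least the `totalUs` and `count` fields the page reads:

```typescript
// Hypothetical input (values invented for illustration):
const byCategory = {
  sync: { totalUs: 900_000, count: 3 },
  "sync/compute": { totalUs: 600_000, count: 3 },
  "sync/transfer": { totalUs: 250_000, count: 3 },
};

// parsePhases(byCategory) would then yield:
// [{
//   name: "sync",
//   totalUs: 900_000,   // taken from the outer "sync" span
//   stepCount: 3,       // count of outer span events
//   subcategories: [    // sorted by totalUs, descending
//     { name: "compute",  stats: { totalUs: 600_000, count: 3 } },
//     { name: "transfer", stats: { totalUs: 250_000, count: 3 } },
//   ],
// }]
```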
  {
    "path": "dashboard/svelte.config.js",
    "content": "import adapter from '@sveltejs/adapter-static';\nimport { vitePreprocess } from '@sveltejs/vite-plugin-svelte';\n\n/** @type {import('@sveltejs/kit').Config} */\nconst config = {\n\tpreprocess: [vitePreprocess()],\n\n\tkit: {\n\t\tpaths: {\n\t\t\trelative: true\n\t\t},\n\t\trouter: { type: 'hash' },\n\t\tadapter: adapter({\n\t\t\tpages: 'build',\n\t\t\tassets: 'build',\n\t\t\tfallback: 'index.html',\n\t\t\tprecompress: false,\n\t\t\tstrict: true\n\t\t}),\n\t\talias: {\n\t\t\t$lib: 'src/lib',\n\t\t\t$components: 'src/lib/components'\n\t\t}\n\t}\n};\n\nexport default config;\n\n"
  },
  {
    "path": "dashboard/tsconfig.json",
    "content": "{\n\t\"extends\": \"./.svelte-kit/tsconfig.json\",\n\t\"compilerOptions\": {\n\t\t\"allowJs\": true,\n\t\t\"checkJs\": true,\n\t\t\"esModuleInterop\": true,\n\t\t\"forceConsistentCasingInFileNames\": true,\n\t\t\"resolveJsonModule\": true,\n\t\t\"skipLibCheck\": true,\n\t\t\"sourceMap\": true,\n\t\t\"strict\": true,\n\t\t\"moduleResolution\": \"bundler\"\n\t}\n}\n\n"
  },
  {
    "path": "dashboard/vite.config.ts",
    "content": "import tailwindcss from \"@tailwindcss/vite\";\nimport { sveltekit } from \"@sveltejs/kit/vite\";\nimport { defineConfig } from \"vite\";\n\nexport default defineConfig({\n  plugins: [tailwindcss(), sveltekit()],\n  server: {\n    proxy: {\n      \"/v1\": \"http://localhost:52415\",\n      \"/state\": \"http://localhost:52415\",\n      \"/models\": \"http://localhost:52415\",\n      \"/instance\": \"http://localhost:52415\",\n    },\n  },\n});\n"
  },
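The proxy rules above apply only under `vite dev`: same-origin requests to `/v1`, `/state`, `/models`, and `/instance` are forwarded to the EXO master on port 52415, so store code can use relative URLs that need no changes between dev and the static production build (presumably served alongside the master). An illustrative snippet, not repo code:

```typescript
// In `vite dev`, this same-origin request is proxied to
// http://localhost:52415/state; the relative URL stays the same
// wherever the built dashboard is served from.
const clusterState = await fetch("/state").then((response) => response.json());
```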
  {
    "path": "docs/api.md",
    "content": "# EXO API – Technical Reference\n\nThis document describes the REST API exposed by the **EXO** service, as implemented in:\n\n`src/exo/master/api.py`\n\nThe API is used to manage model instances in the cluster, inspect cluster state, and perform inference using multiple API-compatible interfaces.\n\nBase URL example:\n\n```\nhttp://localhost:52415\n```\n\n## 1. General / Meta Endpoints\n\n### Get Master Node ID\n\n**GET** `/node_id`\n\nReturns the identifier of the current master node.\n\n**Response (example):**\n\n```json\n{\n  \"node_id\": \"node-1234\"\n}\n```\n\n### Get Cluster State\n\n**GET** `/state`\n\nReturns the current state of the cluster, including nodes and active instances.\n\n**Response:**\nJSON object describing topology, nodes, and instances.\n\n### Get Events\n\n**GET** `/events`\n\nReturns the list of internal events recorded by the master (mainly for debugging and observability).\n\n**Response:**\nArray of event objects.\n\n## 2. Model Instance Management\n\n### Create Instance\n\n**POST** `/instance`\n\nCreates a new model instance in the cluster.\n\n**Request body (example):**\n\n```json\n{\n  \"instance\": {\n    \"model_id\": \"llama-3.2-1b\",\n    \"placement\": { }\n  }\n}\n```\n\n**Response:**\nJSON description of the created instance.\n\n### Delete Instance\n\n**DELETE** `/instance/{instance_id}`\n\nDeletes an existing instance by ID.\n\n**Path parameters:**\n\n* `instance_id`: string, ID of the instance to delete\n\n**Response:**\nStatus / confirmation JSON.\n\n### Get Instance\n\n**GET** `/instance/{instance_id}`\n\nReturns details of a specific instance.\n\n**Path parameters:**\n\n* `instance_id`: string\n\n**Response:**\nJSON description of the instance.\n\n### Preview Placements\n\n**GET** `/instance/previews?model_id=...`\n\nReturns possible placement previews for a given model.\n\n**Query parameters:**\n\n* `model_id`: string, required\n\n**Response:**\nArray of placement preview objects.\n\n### Compute Placement\n\n**GET** `/instance/placement`\n\nComputes a placement for a potential instance without creating it.\n\n**Query parameters (typical):**\n\n* `model_id`: string\n* `sharding`: string or config\n* `instance_meta`: JSON-encoded metadata\n* `min_nodes`: integer\n\n**Response:**\nJSON object describing the proposed placement / instance configuration.\n\n### Place Instance (Dry Operation)\n\n**POST** `/place_instance`\n\nPerforms a placement operation for an instance (planning step), without necessarily creating it.\n\n**Request body:**\nJSON describing the instance to be placed.\n\n**Response:**\nPlacement result.\n\n## 3. 
Models\n\n### List Models\n\n**GET** `/models`\n**GET** `/v1/models` (alias)\n\nReturns the list of available models and their metadata.\n\n**Query parameters:**\n\n* `status`: string (optional) - Filter by `downloaded` to show only downloaded models\n\n**Response:**\nArray of model descriptors including `is_custom` field for custom HuggingFace models.\n\n### Add Custom Model\n\n**POST** `/models/add`\n\nAdd a custom model from HuggingFace hub.\n\n**Request body (example):**\n\n```json\n{\n  \"model_id\": \"mlx-community/my-custom-model\"\n}\n```\n\n**Response:**\nModel descriptor for the added model.\n\n**Security note:**\nModels with `trust_remote_code` enabled in their configuration require explicit opt-in (default is false) for security.\n\n### Delete Custom Model\n\n**DELETE** `/models/custom/{model_id}`\n\nDelete a user-added custom model card.\n\n**Path parameters:**\n\n* `model_id`: string, ID of the custom model to delete\n\n**Response:**\nConfirmation JSON with deleted model ID.\n\n### Search Models\n\n**GET** `/models/search`\n\nSearch HuggingFace Hub for mlx-community models.\n\n**Query parameters:**\n\n* `query`: string (optional) - Search query\n* `limit`: integer (default: 20) - Maximum number of results\n\n**Response:**\nArray of HuggingFace model search results.\n\n## 4. Inference / Chat Completions\n\n### OpenAI-Compatible Chat Completions\n\n**POST** `/v1/chat/completions`\n\nExecutes a chat completion request using an OpenAI-compatible schema. Supports streaming and non-streaming modes.\n\n**Request body (example):**\n\n```json\n{\n  \"model\": \"llama-3.2-1b\",\n  \"messages\": [\n    { \"role\": \"system\", \"content\": \"You are a helpful assistant.\" },\n    { \"role\": \"user\", \"content\": \"Hello\" }\n  ],\n  \"stream\": false\n}\n```\n\n**Request parameters:**\n\n* `model`: string, required - Model ID to use\n* `messages`: array, required - Conversation messages\n* `stream`: boolean (default: false) - Enable streaming responses\n* `max_tokens`: integer (optional) - Maximum tokens to generate\n* `temperature`: float (optional) - Sampling temperature\n* `top_p`: float (optional) - Nucleus sampling parameter\n* `top_k`: integer (optional) - Top-k sampling parameter\n* `stop`: string or array (optional) - Stop sequences\n* `seed`: integer (optional) - Random seed for reproducibility\n* `enable_thinking`: boolean (optional) - Enable thinking mode for capable models (DeepSeek V3.1, Qwen3, GLM-4.7)\n* `tools`: array (optional) - Tool definitions for function calling\n* `logprobs`: boolean (optional) - Return log probabilities\n* `top_logprobs`: integer (optional) - Number of top log probabilities to return\n\n**Response:**\nOpenAI-compatible chat completion response.\n\n**Streaming response format:**\nWhen `stream=true`, returns Server-Sent Events (SSE) with format:\n\n```\ndata: {\"id\":\"...\",\"object\":\"chat.completion\",\"created\":...,\"model\":\"...\",\"choices\":[...]}\n\ndata: [DONE]\n```\n\n**Non-streaming response includes usage statistics:**\n\n```json\n{\n  \"id\": \"...\",\n  \"object\": \"chat.completion\",\n  \"created\": 1234567890,\n  \"model\": \"llama-3.2-1b\",\n  \"choices\": [{\n    \"index\": 0,\n    \"message\": {\n      \"role\": \"assistant\",\n      \"content\": \"Hello! How can I help you?\"\n    },\n    \"finish_reason\": \"stop\"\n  }],\n  \"usage\": {\n    \"prompt_tokens\": 15,\n    \"completion_tokens\": 8,\n    \"total_tokens\": 23\n  }\n}\n```\n\n**Cancellation:**\nYou can cancel an active generation by closing the HTTP connection. 
The server detects the disconnection and stops processing.\n\n### Claude Messages API\n\n**POST** `/v1/messages`\n\nExecutes a chat completion request using the Claude Messages API format. Supports streaming and non-streaming modes.\n\n**Request body (example):**\n\n```json\n{\n  \"model\": \"llama-3.2-1b\",\n  \"messages\": [\n    { \"role\": \"user\", \"content\": \"Hello\" }\n  ],\n  \"max_tokens\": 1024,\n  \"stream\": false\n}\n```\n\n**Streaming response format:**\nWhen `stream=true`, returns Server-Sent Events with Claude-specific event types:\n\n* `message_start` - Message generation started\n* `content_block_start` - Content block started\n* `content_block_delta` - Incremental content chunk\n* `content_block_stop` - Content block completed\n* `message_delta` - Message metadata updates\n* `message_stop` - Message generation completed\n\n**Response:**\nClaude-compatible messages response.\n\n### OpenAI Responses API\n\n**POST** `/v1/responses`\n\nExecutes a chat completion request using the OpenAI Responses API format. Supports streaming and non-streaming modes.\n\n**Request body (example):**\n\n```json\n{\n  \"model\": \"llama-3.2-1b\",\n  \"messages\": [\n    { \"role\": \"user\", \"content\": \"Hello\" }\n  ],\n  \"stream\": false\n}\n```\n\n**Streaming response format:**\nWhen `stream=true`, returns Server-Sent Events with response-specific event types:\n\n* `response.created` - Response generation started\n* `response.in_progress` - Response is being generated\n* `response.output_item.added` - New output item added\n* `response.output_item.done` - Output item completed\n* `response.done` - Response generation completed\n\n**Response:**\nOpenAI Responses API-compatible response.\n\n### Benchmarked Chat Completions\n\n**POST** `/bench/chat/completions`\n\nSame as `/v1/chat/completions`, but also returns performance and generation statistics.\n\n**Request body:**\nSame schema as `/v1/chat/completions`.\n\n**Response:**\nChat completion plus benchmarking metrics including:\n\n* `prompt_tps` - Tokens per second during prompt processing\n* `generation_tps` - Tokens per second during generation\n* `prompt_tokens` - Number of prompt tokens\n* `generation_tokens` - Number of generated tokens\n* `peak_memory_usage` - Peak memory used during generation\n\n### Cancel Command\n\n**POST** `/v1/cancel/{command_id}`\n\nCancels an active generation command (text or image). Notifies workers and closes the stream.\n\n**Path parameters:**\n\n* `command_id`: string, ID of the command to cancel\n\n**Response (example):**\n\n```json\n{\n  \"message\": \"Command cancelled.\",\n  \"command_id\": \"cmd-abc-123\"\n}\n```\n\nReturns 404 if the command is not found or already completed.\n\n
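For example, a sketch of cancelling a command by ID (the base URL and command ID are placeholders, not values EXO guarantees):\n\n```python\nimport httpx\n\n# Cancel an in-flight generation command (placeholder address and ID).\nresponse = httpx.post(\"http://localhost:52415/v1/cancel/cmd-abc-123\")\nif response.status_code == 404:\n    print(\"Command not found or already completed.\")\nelse:\n    print(response.json())  # e.g. {\"message\": \"Command cancelled.\", ...}\n```\n\n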
## 5. Ollama API Compatibility\n\nEXO provides Ollama API compatibility for tools like OpenWebUI.\n\n### Ollama Chat\n\n**POST** `/ollama/api/chat`\n**POST** `/ollama/api/api/chat` (alias)\n**POST** `/ollama/api/v1/chat` (alias)\n\nExecutes a chat request using the Ollama API format.\n\n**Request body (example):**\n\n```json\n{\n  \"model\": \"llama-3.2-1b\",\n  \"messages\": [\n    { \"role\": \"user\", \"content\": \"Hello\" }\n  ],\n  \"stream\": false\n}\n```\n\n**Response:**\nOllama-compatible chat response.\n\n### Ollama Generate\n\n**POST** `/ollama/api/generate`\n\nExecutes a text generation request using the Ollama API format.\n\n**Request body (example):**\n\n```json\n{\n  \"model\": \"llama-3.2-1b\",\n  \"prompt\": \"Hello\",\n  \"stream\": false\n}\n```\n\n**Response:**\nOllama-compatible generation response.\n\n### Ollama Tags\n\n**GET** `/ollama/api/tags`\n**GET** `/ollama/api/api/tags` (alias)\n**GET** `/ollama/api/v1/tags` (alias)\n\nReturns the list of downloaded models in Ollama tags format.\n\n**Response:**\nArray of model tags with metadata.\n\n### Ollama Show\n\n**POST** `/ollama/api/show`\n\nReturns model information in Ollama show format.\n\n**Request body:**\n\n```json\n{\n  \"name\": \"llama-3.2-1b\"\n}\n```\n\n**Response:**\nModel details including modelfile and family.\n\n### Ollama PS\n\n**GET** `/ollama/api/ps`\n\nReturns the list of running models (active instances).\n\n**Response:**\nArray of active model instances.\n\n### Ollama Version\n\n**GET** `/ollama/api/version`\n**HEAD** `/ollama/` (alias)\n**HEAD** `/ollama/api/version` (alias)\n\nReturns version information for Ollama API compatibility.\n\n**Response:**\n\n```json\n{\n  \"version\": \"exo v1.0\"\n}\n```\n\n
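Because these endpoints live under the `/ollama` prefix, an Ollama-format request looks like the minimal httpx sketch below (the base URL is an assumption; substitute your EXO server's address):\n\n```python\nimport httpx\n\n# Non-streaming chat request against EXO's Ollama compatibility layer.\nresponse = httpx.post(\n    \"http://localhost:52415/ollama/api/chat\",  # assumed address\n    json={\n        \"model\": \"llama-3.2-1b\",\n        \"messages\": [{\"role\": \"user\", \"content\": \"Hello\"}],\n        \"stream\": False,\n    },\n    timeout=None,\n)\nprint(response.json())  # Ollama-compatible chat response\n```\n\n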
## 6. Image Generation & Editing\n\n### Image Generation\n\n**POST** `/v1/images/generations`\n\nExecutes an image generation request using an OpenAI-compatible schema with an additional `advanced_params` object. Supports both streaming and non-streaming modes.\n\n**Request body (example):**\n\n```json\n{\n  \"prompt\": \"a robot playing chess\",\n  \"model\": \"exolabs/FLUX.1-dev\",\n  \"n\": 1,\n  \"size\": \"1024x1024\",\n  \"stream\": false,\n  \"response_format\": \"b64_json\"\n}\n```\n\n**Request parameters:**\n\n* `prompt`: string, required - Text description of the image\n* `model`: string, required - Image model ID\n* `n`: integer (default: 1) - Number of images to generate\n* `size`: string (default: \"auto\") - Image dimensions. Supported sizes:\n  - `512x512`\n  - `768x768`\n  - `1024x768`\n  - `768x1024`\n  - `1024x1024`\n  - `1024x1536`\n  - `1536x1024`\n  - `1024x1365`\n  - `1365x1024`\n* `stream`: boolean (default: false) - Enable streaming for partial images\n* `partial_images`: integer (default: 0) - Number of partial images to stream during generation\n* `response_format`: string (default: \"b64_json\") - Either `url` or `b64_json`\n* `quality`: string (default: \"medium\") - Either `high`, `medium`, or `low`\n* `output_format`: string (default: \"png\") - Either `png`, `jpeg`, or `webp`\n* `advanced_params`: object (optional) - Advanced generation parameters\n\n**Advanced Parameters (`advanced_params`):**\n\n| Parameter | Type | Constraints | Description |\n|-----------|------|-------------|-------------|\n| `seed` | int | >= 0 | Random seed for reproducible generation |\n| `num_inference_steps` | int | 1-100 | Number of denoising steps |\n| `guidance` | float | 1.0-20.0 | Classifier-free guidance scale |\n| `negative_prompt` | string | - | Text describing what to avoid in the image |\n\n**Non-streaming response:**\n\n```json\n{\n  \"created\": 1234567890,\n  \"data\": [\n    {\n      \"b64_json\": \"iVBORw0KGgoAAAANSUhEUgAA...\",\n      \"url\": null\n    }\n  ]\n}\n```\n\n**Streaming response format:**\nWhen `stream=true` and `partial_images > 0`, returns Server-Sent Events:\n\n```\ndata: {\"type\":\"partial\",\"image_index\":0,\"partial_index\":1,\"total_partials\":5,\"format\":\"png\",\"data\":{\"b64_json\":\"...\"}}\n\ndata: {\"type\":\"final\",\"image_index\":0,\"format\":\"png\",\"data\":{\"b64_json\":\"...\"}}\n\ndata: [DONE]\n```\n\n
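A minimal non-streaming sketch with httpx (the base URL is an assumption; the request fields match the example above):\n\n```python\nimport base64\nimport httpx\n\n# Generate one image and write the decoded base64 payload to disk.\nresponse = httpx.post(\n    \"http://localhost:52415/v1/images/generations\",  # assumed address\n    json={\n        \"prompt\": \"a robot playing chess\",\n        \"model\": \"exolabs/FLUX.1-dev\",\n        \"n\": 1,\n        \"size\": \"1024x1024\",\n        \"response_format\": \"b64_json\",\n    },\n    timeout=None,\n)\nimage_b64 = response.json()[\"data\"][0][\"b64_json\"]\nwith open(\"robot_chess.png\", \"wb\") as image_file:\n    image_file.write(base64.b64decode(image_b64))\n```\n\n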
### Image Editing\n\n**POST** `/v1/images/edits`\n\nExecutes an image editing request (img2img) using FLUX.1-Kontext-dev or similar models.\n\n**Request (multipart/form-data):**\n\n* `image`: file, required - Input image to edit\n* `prompt`: string, required - Text description of desired changes\n* `model`: string, required - Image editing model ID (e.g., `exolabs/FLUX.1-Kontext-dev`)\n* `n`: integer (default: 1) - Number of edited images to generate\n* `size`: string (optional) - Output image dimensions\n* `response_format`: string (default: \"b64_json\") - Either `url` or `b64_json`\n* `input_fidelity`: string (default: \"low\") - Either `low` or `high` - Controls how closely the output follows the input image\n* `stream`: string (default: \"false\") - Enable streaming\n* `partial_images`: string (default: \"0\") - Number of partial images to stream\n* `quality`: string (default: \"medium\") - Either `high`, `medium`, or `low`\n* `output_format`: string (default: \"png\") - Either `png`, `jpeg`, or `webp`\n* `advanced_params`: string (optional) - JSON-encoded advanced parameters\n\n**Response:**\nSame format as `/v1/images/generations`.\n\n### Benchmarked Image Generation\n\n**POST** `/bench/images/generations`\n\nSame as `/v1/images/generations`, but also returns generation statistics.\n\n**Request body:**\nSame schema as `/v1/images/generations`.\n\n**Response:**\nImage generation plus benchmarking metrics including:\n\n* `seconds_per_step` - Average time per denoising step\n* `total_generation_time` - Total generation time\n* `num_inference_steps` - Number of inference steps used\n* `num_images` - Number of images generated\n* `image_width` - Output image width\n* `image_height` - Output image height\n* `peak_memory_usage` - Peak memory used during generation\n\n### Benchmarked Image Editing\n\n**POST** `/bench/images/edits`\n\nSame as `/v1/images/edits`, but also returns generation statistics.\n\n**Request:**\nSame schema as `/v1/images/edits`.\n\n**Response:**\nSame format as `/bench/images/generations`, including `generation_stats`.\n\n### List Images\n\n**GET** `/images`\n\nList all stored images.\n\n**Response:**\nArray of image metadata including URLs and expiration times.\n\n### Get Image\n\n**GET** `/images/{image_id}`\n\nRetrieve a stored image by ID.\n\n**Path parameters:**\n\n* `image_id`: string, ID of the image\n\n**Response:**\nImage file with appropriate content type.\n\n## 7. Complete Endpoint Summary\n\n```\n# General\nGET     /node_id\nGET     /state\nGET     /events\n\n# Instance Management\nPOST    /instance\nGET     /instance/{instance_id}\nDELETE  /instance/{instance_id}\nGET     /instance/previews\nGET     /instance/placement\nPOST    /place_instance\n\n# Models\nGET     /models\nGET     /v1/models\nPOST    /models/add\nDELETE  /models/custom/{model_id}\nGET     /models/search\n\n# Text Generation (OpenAI Chat Completions)\nPOST    /v1/chat/completions\nPOST    /bench/chat/completions\n\n# Text Generation (Claude Messages API)\nPOST    /v1/messages\n\n# Text Generation (OpenAI Responses API)\nPOST    /v1/responses\n\n# Text Generation (Ollama API)\nPOST    /ollama/api/chat\nPOST    /ollama/api/api/chat\nPOST    /ollama/api/v1/chat\nPOST    /ollama/api/generate\nGET     /ollama/api/tags\nGET     /ollama/api/api/tags\nGET     /ollama/api/v1/tags\nPOST    /ollama/api/show\nGET     /ollama/api/ps\nGET     /ollama/api/version\nHEAD    /ollama/\nHEAD    /ollama/api/version\n\n# Command Control\nPOST    /v1/cancel/{command_id}\n\n# Image Generation\nPOST    /v1/images/generations\nPOST    /bench/images/generations\nPOST    /v1/images/edits\nPOST    /bench/images/edits\nGET     /images\nGET     /images/{image_id}\n```\n\n## 8. Notes\n\n### API Compatibility\n\nEXO provides multiple API-compatible interfaces:\n\n* **OpenAI Chat Completions API** - Compatible with OpenAI clients and tools\n* **Claude Messages API** - Compatible with Anthropic's Claude API format\n* **OpenAI Responses API** - Compatible with OpenAI's Responses API format\n* **Ollama API** - Compatible with Ollama and tools like OpenWebUI\n\nExisting OpenAI, Claude, or Ollama clients can be pointed at EXO by changing the base URL.\n\n### Custom Models\n\nYou can add custom models from HuggingFace using the `/models/add` endpoint. Custom models are identified by the `is_custom` field in model list responses.\n\n**Security:** Models requiring `trust_remote_code` must be explicitly enabled (the default is false). Only enable this if you trust the model's remote code.\n\n### Usage Statistics\n\nChat completion responses include usage statistics with:\n\n* `prompt_tokens` - Number of tokens in the prompt\n* `completion_tokens` - Number of tokens generated\n* `total_tokens` - Sum of prompt and completion tokens\n\n### Request Cancellation\n\nYou can cancel active requests by:\n\n1. Closing the HTTP connection (for streaming requests)\n2. Calling `/v1/cancel/{command_id}` (for any request)\n\nThe server detects cancellation and stops processing immediately.\n\n### Instance Placement\n\nThe instance placement endpoints let you plan and preview cluster allocations before creating instances, which helps optimize resource usage across nodes.\n\n### Observability\n\nThe `/events` and `/state` endpoints are primarily intended for operational visibility and debugging.\n"
  },
  {
    "path": "docs/architecture.md",
    "content": "# EXO Architecture overview\n\nEXO uses an _Event Sourcing_ architecture, and Erlang-style _message passing_. To facilitate this, we've written a channel library extending anyio channels with inspiration from tokio::sync::mpsc. \n\nEach logical module - designed to be functional independently of the others - communicates with the rest of the system by sending messages on topics.\n\n## Systems\n\nThere are currently 5 major systems:\n\n- Master\n    \n    Executes placement and orders events through a single writer\n\n- Worker\n    \n    Schedules work on a node, gathers system information, etc.#\n\n- Runner\n    \n    Executes inference jobs (for now) in an isolated process from the worker for fault-tolerance.\n\n- API\n    \n    Runs a python webserver for exposing state and commands to client applications\n\n- Election\n    \n    Implements a distributed algorithm for master election in unstable networking conditions\n\n## API Layer\n\nThe API system uses multiple adapters to support multiple API formats, converting them to a single request / response type.\n\n### Adapter Pattern\n\nAdapters convert between external API formats and EXO's internal types:\n\n```\nChat Completions → [adapter] → TextGenerationTaskParams → Application\nClaude Messages  → [adapter] → TextGenerationTaskParams → Application\nResponses API    → [adapter] → TextGenerationTaskParams → Application\nOllama API       → [adapter] → TextGenerationTaskParams → Application\n```\n\nEach adapter implements two key functions:\n1. **Request conversion**: Converts API-specific requests to `TextGenerationTaskParams`\n2. **Response generation**: Converts internal `TokenChunk` streams back to API-specific formats (streaming and non-streaming)\n\n\n## Topics\n\nThere are currently 5 topics:\n\n- Commands\n\n    The API and Worker instruct the master when the event log isn't sufficient. Namely placement and catchup requests go through Commands atm.\n\n- Local Events\n\n    All nodes write events here, the master reads those events and orders them\n\n- Global Events\n\n    The master writes events here, all nodes read from this topic and fold the produced events into their `State`\n\n- Election Messages\n\n    Before establishing a cluster, nodes communicate here to negotiate a master node.\n\n- Connection Messages\n\n    The networking system write mdns-discovered hardware connections here.\n\n\n## Event Sourcing\n\nLots has been written about event sourcing, but it lets us centralize faulty connections and message ACKing with the following model.\n\nWhenever a device produces side effects, it captures those side effects in an `Event`. `Event`s are then \"applied\" to their model of `State`, which is globally distributed across the cluster. Whenever a command is received, it is combined with state to produce side effects, captured in yet more events. The rule of thumb is \"`Event`s are past tense, `Command`s are imperative\". Telling a node to perform some action like \"place this model\" or \"Give me a copy of the event log\" is represented by a command (The worker's `Task`s are also commands), while \"this node is using 300GB of ram\" is an event. Notably, `Event`s SHOULD never cause side effects on their own. There are a few exceptions to this, we're working out the specifics of generalizing the distributed event sourcing model to make it better suit our needs\n\n## Purity\n\nA significant goal of the current design is to make data flow explicit. 
Classes should either represent simple data (`CamelCaseModel`s typically, and `TaggedModel`s for unions) or active `System`s (Erlang `Actor`s), with all transformations of that data being \"referentially transparent\" - destructure and construct new data, don't mutate in place. We have had varying degrees of success with this, and are still exploring where purity makes sense.\n"
  },
  {
    "path": "flake.nix",
    "content": "{\n  description = \"The development environment for Exo\";\n\n  inputs = {\n    nixpkgs.url = \"github:NixOS/nixpkgs/nixos-unstable\";\n\n    flake-parts = {\n      url = \"github:hercules-ci/flake-parts\";\n      inputs.nixpkgs-lib.follows = \"nixpkgs\";\n    };\n\n    crane.url = \"github:ipetkov/crane\";\n\n    fenix = {\n      url = \"github:nix-community/fenix\";\n      inputs.nixpkgs.follows = \"nixpkgs\";\n    };\n\n    treefmt-nix = {\n      url = \"github:numtide/treefmt-nix\";\n      inputs.nixpkgs.follows = \"nixpkgs\";\n    };\n\n    dream2nix = {\n      url = \"github:nix-community/dream2nix\";\n      inputs.nixpkgs.follows = \"nixpkgs\";\n      inputs.pyproject-nix.follows = \"pyproject-nix\";\n    };\n\n    # Python packaging with uv2nix\n    pyproject-nix = {\n      url = \"github:pyproject-nix/pyproject.nix\";\n      inputs.nixpkgs.follows = \"nixpkgs\";\n    };\n\n    uv2nix = {\n      url = \"github:pyproject-nix/uv2nix\";\n      inputs.pyproject-nix.follows = \"pyproject-nix\";\n      inputs.nixpkgs.follows = \"nixpkgs\";\n    };\n\n    pyproject-build-systems = {\n      url = \"github:pyproject-nix/build-system-pkgs\";\n      inputs.pyproject-nix.follows = \"pyproject-nix\";\n      inputs.uv2nix.follows = \"uv2nix\";\n      inputs.nixpkgs.follows = \"nixpkgs\";\n    };\n\n    # Pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)\n    nixpkgs-swift.url = \"github:NixOS/nixpkgs/08dacfca559e1d7da38f3cf05f1f45ee9bfd213c\";\n  };\n\n  nixConfig = {\n    extra-trusted-public-keys = \"exo.cachix.org-1:okq7hl624TBeAR3kV+g39dUFSiaZgLRkLsFBCuJ2NZI=\";\n    extra-substituters = \"https://exo.cachix.org\";\n  };\n\n  outputs =\n    inputs:\n    inputs.flake-parts.lib.mkFlake { inherit inputs; } {\n      systems = [\n        \"x86_64-linux\"\n        \"aarch64-darwin\"\n        \"aarch64-linux\"\n      ];\n\n      imports = [\n        inputs.treefmt-nix.flakeModule\n        ./dashboard/parts.nix\n        ./rust/parts.nix\n        ./python/parts.nix\n      ];\n\n      perSystem =\n        { config, self', inputs', pkgs, lib, system, ... 
}:\n        let\n          # Use pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)\n          pkgsSwift = import inputs.nixpkgs-swift { inherit system; };\n        in\n        {\n          # Allow unfree for metal-toolchain (needed for Darwin Metal packages)\n          _module.args.pkgs = import inputs.nixpkgs {\n            inherit system;\n            config.allowUnfreePredicate = pkg: (pkg.pname or \"\") == \"metal-toolchain\";\n            overlays = [\n              (import ./nix/apple-sdk-overlay.nix)\n            ];\n          };\n          treefmt = {\n            projectRootFile = \"flake.nix\";\n            programs = {\n              nixpkgs-fmt.enable = true;\n              ruff-format = {\n                enable = true;\n                excludes = [ \"rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi\" ];\n              };\n              rustfmt = {\n                enable = true;\n                package = config.rust.toolchain;\n              };\n              prettier = {\n                enable = true;\n                package = self'.packages.prettier-svelte;\n                includes = [ \"*.ts\" \"*.svelte\" ];\n              };\n              swift-format = {\n                enable = true;\n                package = pkgsSwift.swiftPackages.swift-format;\n              };\n              shfmt.enable = true;\n              taplo.enable = true;\n            };\n          };\n\n          packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin (\n            let\n              uvLock = builtins.fromTOML (builtins.readFile ./uv.lock);\n              mlxPackage = builtins.head (builtins.filter (p: p.name == \"mlx\" && p.source ? git) uvLock.package);\n              uvLockMlxVersion = mlxPackage.version;\n              uvLockMlxRev = builtins.elemAt (builtins.split \"#\" mlxPackage.source.git) 2;\n            in\n            {\n              metal-toolchain = pkgs.callPackage ./nix/metal-toolchain.nix { };\n              mlx = pkgs.callPackage ./nix/mlx.nix {\n                inherit (self'.packages) metal-toolchain;\n                inherit uvLockMlxVersion uvLockMlxRev;\n              };\n              default = self'.packages.exo;\n            }\n          );\n\n          devShells.default = with pkgs; pkgs.mkShell {\n            inputsFrom = [ self'.checks.cargo-build ];\n\n            packages =\n              [\n                # FORMATTING\n                config.treefmt.build.wrapper\n\n                # PYTHON\n                python313\n                uv\n                ruff\n                basedpyright\n\n                # RUST\n                config.rust.toolchain\n                maturin\n\n                # NIX\n                nixpkgs-fmt\n\n                # SVELTE\n                nodejs\n\n                # MISC\n                just\n                jq\n              ]\n              ++ lib.optionals stdenv.isLinux [\n                unixtools.ifconfig\n              ]\n              ++ lib.optionals stdenv.isDarwin [\n                macmon\n              ];\n\n            OPENSSL_NO_VENDOR = \"1\";\n\n            shellHook = ''\n              export LD_LIBRARY_PATH=\"$LD_LIBRARY_PATH:${python313}/lib\"\n              ${lib.optionalString stdenv.isLinux ''\n                export LD_LIBRARY_PATH=\"${openssl.out}/lib:$LD_LIBRARY_PATH\"\n              ''}\n            '';\n          };\n        };\n    };\n}\n"
  },
  {
    "path": "justfile",
    "content": "export NIX_CONFIG := \"extra-experimental-features = nix-command flakes\"\n\nfmt:\n    treefmt || nix fmt\n\nlint:\n    uv run ruff check --fix\n\ntest:\n    uv run pytest src\n\ncheck:\n    uv run basedpyright --project pyproject.toml\n\nsync:\n    uv sync --all-packages\n\nsync-clean:\n    uv sync --all-packages --force-reinstall --no-cache\n\nrust-rebuild:\n    cargo run --bin stub_gen\n    uv sync --reinstall-package exo_pyo3_bindings\n\nbuild-dashboard:\n    #!/usr/bin/env bash\n    cd dashboard\n    npm install\n    npm run build\n\npackage:\n    uv run pyinstaller packaging/pyinstaller/exo.spec\n\nclean:\n    rm -rf **/__pycache__\n    rm -rf target/\n    rm -rf .venv\n    rm -rf dashboard/node_modules\n    rm -rf dashboard/.svelte-kit\n    rm -rf dashboard/build\n"
  },
  {
    "path": "nix/apple-sdk/metadata/versions.json",
    "content": "{\n  \"14\": {\n    \"urls\": [\n      \"https://swcdn.apple.com/content/downloads/14/48/052-59890-A_I0F5YGAY0Y/p9n40hio7892gou31o1v031ng6fnm9sb3c/CLTools_macOSNMOS_SDK.pkg\",\n      \"https://web.archive.org/web/20250211001355/https://swcdn.apple.com/content/downloads/14/48/052-59890-A_I0F5YGAY0Y/p9n40hio7892gou31o1v031ng6fnm9sb3c/CLTools_macOSNMOS_SDK.pkg\"\n    ],\n    \"version\": \"14.4\",\n    \"hash\": \"sha256-QozDiwY0Czc0g45vPD7G4v4Ra+3DujCJbSads3fJjjM=\"\n  },\n  \"15\": {\n    \"urls\": [\n      \"https://swcdn.apple.com/content/downloads/52/01/082-41241-A_0747ZN8FHV/dectd075r63pppkkzsb75qk61s0lfee22j/CLTools_macOSNMOS_SDK.pkg\",\n      \"https://web.archive.org/web/20250530132510/https://swcdn.apple.com/content/downloads/52/01/082-41241-A_0747ZN8FHV/dectd075r63pppkkzsb75qk61s0lfee22j/CLTools_macOSNMOS_SDK.pkg\"\n    ],\n    \"version\": \"15.5\",\n    \"hash\": \"sha256-HBiSJuw1XBUK5R/8Sj65c3rftSEvQl/O9ZZVp/g1Amo=\"\n  },\n  \"26\": {\n    \"urls\": [\n      \"https://swcdn.apple.com/content/downloads/60/22/089-71960-A_W8BL1RUJJ6/5zkyplomhk1cm7z6xja2ktgapnhhti6wwd/CLTools_macOSNMOS_SDK.pkg\",\n      \"https://web.archive.org/web/20250915230423/https://swcdn.apple.com/content/downloads/60/22/089-71960-A_W8BL1RUJJ6/5zkyplomhk1cm7z6xja2ktgapnhhti6wwd/CLTools_macOSNMOS_SDK.pkg\"\n    ],\n    \"version\": \"26.2\",\n    \"hash\": \"sha256-hXRlMieVv0smna5uiWRwq87IWOaPWtAjAldbi+wQXcw=\"\n  }\n}\n"
  },
  {
    "path": "nix/apple-sdk-overlay.nix",
    "content": "# Overlay that builds apple-sdk with a custom versions.json (for SDK 26.2).\n# The upstream nixpkgs package reads versions.json at eval time via a relative\n# path, so we can't override it through callPackage args. Instead, we copy\n# the upstream source and patch the one file.\nfinal: _prev:\nlet\n  upstreamSrc = final.path + \"/pkgs/by-name/ap/apple-sdk\";\n  patchedSrc = final.runCommandLocal \"apple-sdk-src-patched\" { } ''\n    cp -r ${upstreamSrc} $out\n    chmod -R u+w $out\n    cp ${./apple-sdk/metadata/versions.json} $out/metadata/versions.json\n  '';\nin\n{\n  apple-sdk_26 = final.callPackage (patchedSrc + \"/package.nix\") {\n    darwinSdkMajorVersion = \"26\";\n  };\n}\n"
  },
  {
    "path": "nix/darwin-build-fixes.patch",
    "content": "diff --git a/CMakeLists.txt b/CMakeLists.txt\nindex 0ed30932..d8528132 100644\n--- a/CMakeLists.txt\n+++ b/CMakeLists.txt\n@@ -177,11 +177,7 @@ if(MLX_BUILD_METAL)\n     add_compile_definitions(MLX_METAL_DEBUG)\n   endif()\n\n-  # Throw an error if xcrun not found\n-  execute_process(\n-    COMMAND zsh \"-c\" \"/usr/bin/xcrun -sdk macosx --show-sdk-version\"\n-    OUTPUT_VARIABLE MACOS_SDK_VERSION\n-    OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)\n+  set(MACOS_SDK_VERSION @sdkVersion@)\n\n   if(${MACOS_SDK_VERSION} LESS 14.0)\n     message(\n@@ -199,11 +195,8 @@ if(MLX_BUILD_METAL)\n     endif()\n     set(XCRUN_FLAGS \"-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}\")\n   endif()\n-  execute_process(\n-    COMMAND\n-      zsh \"-c\"\n-      \"echo \\\"__METAL_VERSION__\\\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\\n'\"\n-    OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)\n+  set(\n+    MLX_METAL_VERSION @metalVersion@)\n   FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})\n   FetchContent_MakeAvailable(metal_cpp)\n   target_include_directories(\ndiff --git a/cmake/extension.cmake b/cmake/extension.cmake\nindex 13db804a..5b385132 100644\n--- a/cmake/extension.cmake\n+++ b/cmake/extension.cmake\n@@ -36,7 +36,7 @@ macro(mlx_build_metallib)\n   add_custom_command(\n     OUTPUT ${MTLLIB_BUILD_TARGET}\n     COMMAND\n-      xcrun -sdk macosx metal\n+      metal -fmodules-cache-path=${CMAKE_BINARY_DIR}/metal-cache\n       \"$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>\"\n       ${MTLLIB_COMPILE_OPTIONS} ${MTLLIB_SOURCES} -o ${MTLLIB_BUILD_TARGET}\n     DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}\ndiff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt\nindex 262b0495..5c7446ad 100644\n--- a/mlx/backend/metal/kernels/CMakeLists.txt\n+++ b/mlx/backend/metal/kernels/CMakeLists.txt\n@@ -29,7 +29,7 @@ function(build_kernel_base TARGET SRCFILE DEPS)\n                     \"-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}\")\n   endif()\n   add_custom_command(\n-    COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE}\n+    COMMAND metal -fmodules-cache-path=${CMAKE_BINARY_DIR}/metal-cache ${METAL_FLAGS} -c ${SRCFILE}\n             -I${PROJECT_SOURCE_DIR} -o ${TARGET}.air\n     DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}\n     OUTPUT ${TARGET}.air\n@@ -170,7 +170,7 @@ endif()\n\n add_custom_command(\n   OUTPUT ${MLX_METAL_PATH}/mlx.metallib\n-  COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o\n+  COMMAND metallib ${KERNEL_AIR} -o\n           ${MLX_METAL_PATH}/mlx.metallib\n   DEPENDS ${KERNEL_AIR}\n   COMMENT \"Building mlx.metallib\"\ndiff --git a/mlx/backend/metal/make_compiled_preamble.sh b/mlx/backend/metal/make_compiled_preamble.sh\nindex bb55ed3a..94ea7dd7 100644\n--- a/mlx/backend/metal/make_compiled_preamble.sh\n+++ b/mlx/backend/metal/make_compiled_preamble.sh\n@@ -31,7 +31,7 @@ OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp\n mkdir -p \"$OUTPUT_DIR\"\n\n # Use the metal compiler to get a list of headers (with depth)\n-CCC=\"xcrun -sdk macosx metal -x metal\"\n+CCC=\"metal -x metal -fmodules-cache-path=${OUTPUT_DIR}/metal-cache\"\n HDRS=$( $CCC -I\"$SRC_DIR\" -I\"$JIT_INCLUDES\" -DMLX_METAL_JIT -E -P -CC -C -H \"$INPUT_FILE\" $CFLAGS -w 2>&1 1>/dev/null )\n\n # Remove any included system frameworks (for MetalPerformancePrimitive headers)\n"
  },
  {
    "path": "nix/metal-toolchain.nix",
    "content": "{ lib, stdenvNoCC, requireFile, nix }:\n\nlet\n  narFile = requireFile {\n    name = \"metal-toolchain-17C48.nar\";\n    message = ''\n      The Metal Toolchain NAR must be available.\n\n      If you have cachix configured for exo.cachix.org, this should be automatic.\n\n      Otherwise:\n        1. Install Xcode 26+ from the App Store\n        2. Run: xcodebuild -downloadComponent MetalToolchain\n        3. Export the toolchain:\n           hdiutil attach \"$(find /System/Library/AssetsV2/com_apple_MobileAsset_MetalToolchain -name '*.dmg' | head -1)\" -mountpoint /tmp/metal-dmg\n           cp -R /tmp/metal-dmg/Metal.xctoolchain /tmp/metal-export\n           hdiutil detach /tmp/metal-dmg\n        4. Create NAR and add to store:\n           nix nar pack /tmp/metal-export > /tmp/metal-toolchain-17C48.nar\n           nix store add --mode flat /tmp/metal-toolchain-17C48.nar\n    '';\n    hash = \"sha256-ayR5mXN4sZAddwKEG2OszGRF93k9ZFc7H0yi2xbylQw=\";\n  };\nin\nstdenvNoCC.mkDerivation {\n  pname = \"metal-toolchain\";\n  version = \"17C48\";\n\n  dontUnpack = true;\n  dontBuild = true;\n  dontFixup = true;\n\n  nativeBuildInputs = [ nix ];\n\n  installPhase = ''\n    runHook preInstall\n\n    nix-store --restore $out < ${narFile}\n\n    # Create bin directory with symlinks for PATH\n    mkdir -p $out/bin\n    ln -s $out/usr/bin/metal $out/bin/metal\n    ln -s $out/usr/bin/metallib $out/bin/metallib\n\n    runHook postInstall\n  '';\n\n  # Metal language version for CMake (from: echo __METAL_VERSION__ | metal -E -x metal -P -)\n  passthru.metalVersion = \"400\";\n\n  meta = {\n    description = \"Apple Metal compiler toolchain\";\n    platforms = [ \"aarch64-darwin\" ];\n    license = lib.licenses.unfree;\n  };\n}\n"
  },
  {
    "path": "nix/mlx.nix",
    "content": "{ stdenv\n, lib\n, fetchFromGitHub\n, replaceVars\n, fetchzip\n, cmake\n, nlohmann_json\n, apple-sdk_26\n, metal-toolchain\n, runCommand\n, fmt\n, python313Packages\n, uvLockMlxVersion\n, uvLockMlxRev\n}:\n\nassert stdenv.isDarwin;\n\nlet\n  python = python313Packages.python;\n\n  # Static dependencies included directly during compilation\n  gguf-tools = fetchFromGitHub {\n    owner = \"antirez\";\n    repo = \"gguf-tools\";\n    rev = \"8fa6eb65236618e28fd7710a0fba565f7faa1848\";\n    hash = \"sha256-15FvyPOFqTOr5vdWQoPnZz+mYH919++EtghjozDlnSA=\";\n  };\n\n  metal_cpp = fetchzip {\n    url = \"https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip\";\n    hash = \"sha256-7n2eI2lw/S+Us6l7YPAATKwcIbRRpaQ8VmES7S8ZjY8=\";\n  };\n\n  nanobind = fetchFromGitHub {\n    owner = \"wjakob\";\n    repo = \"nanobind\";\n    rev = \"v2.10.2\";\n    hash = \"sha256-io44YhN+VpfHFWyvvLWSanRgbzA0whK8WlDNRi3hahU=\";\n    fetchSubmodules = true;\n  };\n\n  mlx = stdenv.mkDerivation rec {\n    pname = \"mlx\";\n    version = uvLockMlxVersion;\n    pyproject = true;\n\n    src = fetchFromGitHub {\n      owner = \"rltakashige\";\n      repo = \"mlx-jaccl-fix-small-recv\";\n      rev = uvLockMlxRev;\n      hash = \"sha256-GosFIWxIB48Egb1MqJrR3xhsUsQeWdRk5rV93USY6wQ=\";\n    };\n\n    patches = [\n      (replaceVars ./darwin-build-fixes.patch {\n        sdkVersion = apple-sdk_26.version;\n        metalVersion = metal-toolchain.metalVersion;\n      })\n    ];\n\n    postPatch = ''\n      substituteInPlace mlx/backend/cpu/jit_compiler.cpp \\\n        --replace-fail \"g++\" \"$CXX\"\n    '';\n\n    dontUseCmakeConfigure = true;\n\n    enableParallelBuilding = true;\n\n    # Allows multiple cores to be used in Python builds.\n    postUnpack = ''\n      export MAKEFLAGS+=\"''${enableParallelBuilding:+-j$NIX_BUILD_CORES}\"\n    '';\n\n    # Updates the wrong fetcher rev attribute\n    passthru.skipBulkUpdate = true;\n\n    env = {\n      DEV_RELEASE = 1;\n      CMAKE_ARGS = toString [\n        (lib.cmakeBool \"USE_SYSTEM_FMT\" true)\n        (lib.cmakeOptionType \"filepath\" \"FETCHCONTENT_SOURCE_DIR_GGUFLIB\" \"${gguf-tools}\")\n        (lib.cmakeOptionType \"filepath\" \"FETCHCONTENT_SOURCE_DIR_JSON\" \"${nlohmann_json.src}\")\n        (lib.cmakeOptionType \"filepath\" \"FETCHCONTENT_SOURCE_DIR_NANOBIND\" \"${nanobind}\")\n        (lib.cmakeBool \"FETCHCONTENT_FULLY_DISCONNECTED\" true)\n        (lib.cmakeBool \"MLX_BUILD_CPU\" true)\n        (lib.cmakeBool \"MLX_BUILD_METAL\" true)\n        (lib.cmakeOptionType \"filepath\" \"FETCHCONTENT_SOURCE_DIR_METAL_CPP\" \"${metal_cpp}\")\n        (lib.cmakeOptionType \"string\" \"CMAKE_OSX_DEPLOYMENT_TARGET\" \"${apple-sdk_26.version}\")\n        (lib.cmakeOptionType \"filepath\" \"CMAKE_OSX_SYSROOT\" \"${apple-sdk_26.passthru.sdkroot}\")\n      ];\n      SDKROOT = apple-sdk_26.passthru.sdkroot;\n      MACOSX_DEPLOYMENT_TARGET = apple-sdk_26.version;\n    };\n\n    build-system = [\n      python313Packages.setuptools\n    ];\n\n    nativeBuildInputs = [\n      cmake\n      metal-toolchain\n      python313Packages.pypaBuildHook\n      python313Packages.pypaInstallHook\n      python313Packages.setuptools\n      python313Packages.typing-extensions\n      python313Packages.wheel\n      python313Packages.cmake\n      python313Packages.ninja\n    ];\n\n    buildInputs = [\n      fmt\n      gguf-tools\n      python313Packages.nanobind\n      python313Packages.pybind11\n      apple-sdk_26\n    ];\n\n    # Tests require Metal GPU access which isn't available in 
the Nix sandbox.\n    # To run tests, build with: nix build --option sandbox false .#mlx.passthru.tests.mlxTest\n    doCheck = false;\n\n    pythonImportsCheck = [ \"mlx\" ];\n\n    passthru.tests = {\n      # Runs example scripts to verify MLX works. Requires --option sandbox false\n      # since Metal GPU access is needed.\n      mlxTest =\n        runCommand \"run-mlx-examples\"\n          {\n            buildInputs = [ mlx ];\n            nativeBuildInputs = [ python ];\n          }\n          ''\n            cp ${src}/examples/python/logistic_regression.py .\n            ${python.interpreter} logistic_regression.py\n            rm logistic_regression.py\n\n            cp ${src}/examples/python/linear_regression.py .\n            ${python.interpreter} linear_regression.py\n            rm linear_regression.py\n\n            touch $out\n          '';\n    };\n\n    meta = {\n      homepage = \"https://github.com/ml-explore/mlx\";\n      description = \"Array framework for Apple silicon\";\n      changelog = \"https://github.com/ml-explore/mlx/releases/tag/${src.tag}\";\n      license = lib.licenses.mit;\n      platforms = [ \"aarch64-darwin\" ];\n    };\n  };\nin\nmlx\n"
  },
  {
    "path": "packaging/dmg/create-dmg.sh",
    "content": "#!/usr/bin/env bash\n# create-dmg.sh — Build a polished macOS DMG installer for EXO\n#\n# Usage:\n#   ./packaging/dmg/create-dmg.sh <app-path> <output-dmg> [volume-name]\n#\n# Example:\n#   ./packaging/dmg/create-dmg.sh output/EXO.app EXO-1.0.0.dmg \"EXO\"\n#\n# Creates a DMG with:\n#   - Custom background image with drag-to-Applications arrow\n#   - App icon on left, Applications alias on right\n#   - Proper window size and icon positioning\nset -euo pipefail\n\nAPP_PATH=\"${1:?Usage: create-dmg.sh <app-path> <output-dmg> [volume-name]}\"\nOUTPUT_DMG=\"${2:?Usage: create-dmg.sh <app-path> <output-dmg> [volume-name]}\"\nVOLUME_NAME=\"${3:-EXO}\"\n\nSCRIPT_DIR=\"$(cd \"$(dirname \"$0\")\" && pwd)\"\nBACKGROUND_SCRIPT=\"${SCRIPT_DIR}/generate-background.py\"\nTEMP_DIR=\"$(mktemp -d)\"\nDMG_STAGING=\"${TEMP_DIR}/dmg-root\"\nTEMP_DMG=\"${TEMP_DIR}/temp.dmg\"\nBACKGROUND_PNG=\"${TEMP_DIR}/background.png\"\n\ncleanup() { rm -rf \"$TEMP_DIR\"; }\ntrap cleanup EXIT\n\necho \"==> Creating DMG installer for ${VOLUME_NAME}\"\n\n# ── Step 1: Generate background image ────────────────────────────────────────\nif command -v python3 &>/dev/null; then\n  python3 \"$BACKGROUND_SCRIPT\" \"$BACKGROUND_PNG\"\n  echo \"    Background image generated\"\nelse\n  echo \"    Warning: python3 not found, skipping custom background\"\n  BACKGROUND_PNG=\"\"\nfi\n\n# ── Step 2: Prepare staging directory ─────────────────────────────────────────\nmkdir -p \"$DMG_STAGING\"\ncp -R \"$APP_PATH\" \"$DMG_STAGING/\"\nln -s /Applications \"$DMG_STAGING/Applications\"\n\n# ── Step 3: Create writable DMG ──────────────────────────────────────────────\n# Calculate required size (app size + 20MB headroom)\nAPP_SIZE_KB=$(du -sk \"$APP_PATH\" | cut -f1)\nDMG_SIZE_KB=$((APP_SIZE_KB + 20480))\n\nhdiutil create \\\n  -volname \"$VOLUME_NAME\" \\\n  -size \"${DMG_SIZE_KB}k\" \\\n  -fs HFS+ \\\n  -layout SPUD \\\n  \"$TEMP_DMG\"\n\n# ── Step 4: Mount and configure ──────────────────────────────────────────────\nMOUNT_DIR=$(hdiutil attach \"$TEMP_DMG\" -readwrite -noverify | awk -F'\\t' '/Apple_HFS/ {gsub(/^[[:space:]]+|[[:space:]]+$/, \"\", $NF); print $NF}')\necho \"    Mounted at: $MOUNT_DIR\"\n\n# Copy contents\ncp -R \"$DMG_STAGING/\"* \"$MOUNT_DIR/\"\n\n# Add background image\nif [[ -n $BACKGROUND_PNG && -f $BACKGROUND_PNG ]]; then\n  mkdir -p \"$MOUNT_DIR/.background\"\n  cp \"$BACKGROUND_PNG\" \"$MOUNT_DIR/.background/background.png\"\nfi\n\n# ── Step 5: Configure window appearance via AppleScript ──────────────────────\n# Window: 800×400, app icon on left, Applications on right (matches Ollama layout)\n# Background image is 1600×740 (2× retina for 800×400 logical window).\nAPP_NAME=\"$(basename \"$APP_PATH\")\"\n\nosascript <<APPLESCRIPT\ntell application \"Finder\"\n    tell disk \"$VOLUME_NAME\"\n        open\n        set current view of container window to icon view\n        set toolbar visible of container window to false\n        set statusbar visible of container window to false\n        set bounds of container window to {200, 120, 1000, 520}\n        set opts to icon view options of container window\n        set icon size of opts to 128\n        set text size of opts to 12\n        set arrangement of opts to not arranged\n        if exists file \".background:background.png\" then\n            set background picture of opts to file \".background:background.png\"\n        end if\n        set position of item \"$APP_NAME\" of container window to {200, 190}\n        set position of item \"Applications\" of 
container window to {600, 190}\n        close\n        open\n        update without registering applications\n        delay 1\n        close\n    end tell\nend tell\nAPPLESCRIPT\n\necho \"    Window layout configured\"\n\n# Ensure Finder updates are flushed\nsync\n\n# ── Step 6: Finalise ─────────────────────────────────────────────────────────\nhdiutil detach \"$MOUNT_DIR\" -quiet\nhdiutil convert \"$TEMP_DMG\" -format UDZO -imagekey zlib-level=9 -o \"$OUTPUT_DMG\"\n\necho \"==> DMG created: $OUTPUT_DMG\"\necho \"    Size: $(du -h \"$OUTPUT_DMG\" | cut -f1)\"\n"
  },
  {
    "path": "packaging/dmg/generate-background.py",
    "content": "#!/usr/bin/env python3\n\"\"\"Generate the DMG background image with a centered drag-to-Applications arrow.\n\nThe output is a 1600×740 retina PNG (2× for 800×400 logical window).\nIcons are positioned at (200, 190) and (600, 190) in logical coordinates;\nthe arrow is drawn centered between them.\n\nUsage:\n    python3 generate-background.py [output.png]\n\nIf no output path is given, overwrites the bundled background.png in-place.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport math\nimport sys\nfrom pathlib import Path\n\nfrom PIL import Image, ImageDraw\n\n# Retina dimensions (2× logical 800×400)\nWIDTH = 1600\nHEIGHT = 740\n\n# Icon positions in logical coords → retina coords\n# App icon at (200, 190), Applications at (600, 190)\nAPP_X = 200 * 2  # 400\nAPPS_X = 600 * 2  # 1200\nICON_Y = 190 * 2  # 380\n\n# Arrow drawn between icons, slightly above icon center\nARROW_START_X = APP_X + 160  # past the icon\nARROW_END_X = APPS_X - 160  # before the Applications icon\nARROW_Y = ICON_Y  # same height as icons\nARROW_RISE = 120  # upward arc height\n\n\ndef draw_arrow(draw: ImageDraw.ImageDraw) -> None:\n    \"\"\"Draw a hand-drawn-style curved arrow from app icon toward Applications.\"\"\"\n    color = (30, 30, 30)\n    line_width = 8\n\n    # Compute bezier curve points for a gentle upward arc\n    points: list[tuple[float, float]] = []\n    steps = 80\n    for i in range(steps + 1):\n        t = i / steps\n        # Quadratic bezier: start → control → end\n        cx = (ARROW_START_X + ARROW_END_X) / 2\n        cy = ARROW_Y - ARROW_RISE\n        x = (1 - t) ** 2 * ARROW_START_X + 2 * (1 - t) * t * cx + t**2 * ARROW_END_X\n        y = (1 - t) ** 2 * ARROW_Y + 2 * (1 - t) * t * cy + t**2 * ARROW_Y\n        points.append((x, y))\n\n    # Draw the curve as connected line segments\n    for i in range(len(points) - 1):\n        draw.line([points[i], points[i + 1]], fill=color, width=line_width)\n\n    # Arrowhead at the end\n    end_x, end_y = points[-1]\n    # Direction from second-to-last to last point\n    prev_x, prev_y = points[-3]\n    angle = math.atan2(end_y - prev_y, end_x - prev_x)\n    head_len = 36\n    head_angle = math.radians(25)\n\n    left_x = end_x - head_len * math.cos(angle - head_angle)\n    left_y = end_y - head_len * math.sin(angle - head_angle)\n    right_x = end_x - head_len * math.cos(angle + head_angle)\n    right_y = end_y - head_len * math.sin(angle + head_angle)\n\n    draw.polygon(\n        [(end_x, end_y), (left_x, left_y), (right_x, right_y)],\n        fill=color,\n    )\n\n\ndef generate_background(output_path: str) -> None:\n    \"\"\"Generate a white DMG background with a centered arrow.\"\"\"\n    img = Image.new(\"RGBA\", (WIDTH, HEIGHT), (255, 255, 255, 255))\n    draw = ImageDraw.Draw(img)\n    draw_arrow(draw)\n    img.save(output_path, \"PNG\")\n\n\nif __name__ == \"__main__\":\n    default_output = str(Path(__file__).parent / \"background.png\")\n    out = sys.argv[1] if len(sys.argv) >= 2 else default_output\n    generate_background(out)\n    print(f\"Background image written to {out}\")\n"
  },
  {
    "path": "packaging/pyinstaller/exo.spec",
    "content": "# -*- mode: python ; coding: utf-8 -*-\n\nimport importlib.util\nimport shutil\nfrom pathlib import Path\n\nfrom PyInstaller.utils.hooks import collect_submodules\n\nPROJECT_ROOT = Path.cwd()\nSOURCE_ROOT = PROJECT_ROOT / \"src\"\nENTRYPOINT = SOURCE_ROOT / \"exo\" / \"__main__.py\"\nDASHBOARD_DIR = PROJECT_ROOT / \"dashboard\" / \"build\"\nRESOURCES_DIR = PROJECT_ROOT / \"resources\"\nEXO_SHARED_MODELS_DIR = SOURCE_ROOT / \"exo\" / \"shared\" / \"models\"\n\nif not ENTRYPOINT.is_file():\n    raise SystemExit(f\"Unable to locate Exo entrypoint: {ENTRYPOINT}\")\n\nif not DASHBOARD_DIR.is_dir():\n    raise SystemExit(f\"Dashboard assets are missing: {DASHBOARD_DIR}\")\n\nif not RESOURCES_DIR.is_dir():\n    raise SystemExit(f\"Resource assets are missing: {RESOURCES_DIR}\")\n\nif not EXO_SHARED_MODELS_DIR.is_dir():\n    raise SystemExit(f\"Shared model assets are missing: {EXO_SHARED_MODELS_DIR}\")\n\nblock_cipher = None\n\n\ndef _module_directory(module_name: str) -> Path:\n    spec = importlib.util.find_spec(module_name)\n    if spec is None:\n        raise SystemExit(f\"Module '{module_name}' is not available in the current environment.\")\n    if spec.submodule_search_locations:\n        return Path(next(iter(spec.submodule_search_locations))).resolve()\n    if spec.origin:\n        return Path(spec.origin).resolve().parent\n    raise SystemExit(f\"Unable to determine installation directory for '{module_name}'.\")\n\n\nMLX_PACKAGE_DIR = _module_directory(\"mlx\")\nMLX_LIB_DIR = MLX_PACKAGE_DIR / \"lib\"\nif not MLX_LIB_DIR.is_dir():\n    raise SystemExit(f\"mlx Metal libraries are missing: {MLX_LIB_DIR}\")\n\n\ndef _safe_collect(package_name: str) -> list[str]:\n    try:\n        return collect_submodules(package_name)\n    except ImportError:\n        return []\n\n\nHIDDEN_IMPORTS = sorted(\n    set(\n        collect_submodules(\"mlx\")\n        + _safe_collect(\"mlx_lm\")\n        + _safe_collect(\"transformers\")\n    )\n)\n\nDATAS: list[tuple[str, str]] = [\n    (str(DASHBOARD_DIR), \"dashboard\"),\n    (str(RESOURCES_DIR), \"resources\"),\n    (str(MLX_LIB_DIR), \"mlx/lib\"),\n    (str(EXO_SHARED_MODELS_DIR), \"exo/shared/models\"),\n]\n\nMACMON_PATH = shutil.which(\"macmon\")\nif MACMON_PATH is None:\n    raise SystemExit(\n        \"macmon binary not found in PATH. \"\n        \"Install it via: brew install macmon\"\n    )\n\nBINARIES: list[tuple[str, str]] = [\n    (MACMON_PATH, \".\"),\n]\n\na = Analysis(\n    [str(ENTRYPOINT)],\n    pathex=[str(SOURCE_ROOT)],\n    binaries=BINARIES,\n    datas=DATAS,\n    hiddenimports=HIDDEN_IMPORTS,\n    hookspath=[],\n    hooksconfig={},\n    runtime_hooks=[],\n    excludes=[],\n    win_no_prefer_redirects=False,\n    win_private_assemblies=False,\n    noarchive=False,\n)\npyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)\nexe = EXE(\n    pyz,\n    a.scripts,\n    [],\n    exclude_binaries=True,\n    name=\"exo\",\n    debug=False,\n    bootloader_ignore_signals=False,\n    strip=False,\n    upx=False,\n    console=True,\n    disable_windowed_traceback=False,\n    argv_emulation=False,\n    target_arch=None,\n    codesign_identity=None,\n    entitlements_file=None,\n)\ncoll = COLLECT(\n    exe,\n    a.binaries,\n    a.zipfiles,\n    a.datas,\n    strip=False,\n    upx=False,\n    upx_exclude=[],\n    name=\"exo\",\n)\n\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[project]\nname = \"exo\"\nversion = \"0.3.68\"\ndescription = \"Exo\"\nreadme = \"README.md\"\nrequires-python = \">=3.13\"\ndependencies = [\n  \"aiofiles>=24.1.0\",\n  \"aiohttp>=3.12.14\",\n  \"types-aiofiles>=24.1.0.20250708\",\n  \"pydantic>=2.11.7\",\n  \"fastapi>=0.116.1\",\n  \"filelock>=3.18.0\",\n  \"rustworkx>=0.17.1\",\n  \"huggingface-hub>=0.33.4\",\n  \"psutil>=7.0.0\",\n  \"loguru>=0.7.3\",\n  \"exo_pyo3_bindings\",                         # rust bindings\n  \"anyio==4.11.0\",\n  \"mlx; sys_platform == 'darwin'\",\n  \"mlx[cpu]==0.30.6; sys_platform == 'linux'\",\n  \"mlx-lm\",\n  \"tiktoken>=0.12.0\",                          # required for kimi k2 tokenizer\n  \"hypercorn>=0.18.0\",\n  \"openai-harmony>=0.0.8\",\n  \"httpx>=0.28.1\",\n  \"tomlkit>=0.14.0\",\n  \"mflux==0.16.9\",\n  \"python-multipart>=0.0.21\",\n  \"msgspec>=0.19.0\",\n  \"zstandard>=0.23.0\",\n]\n\n[project.scripts]\nexo = \"exo.main:main\"\n\n# dependencies only required for development\n[dependency-groups]\ndev = [\n  \"basedpyright>=1.29.0\",\n  \"pyinstaller>=6.17.0\",\n  \"pytest>=8.4.0\",\n  \"pytest-asyncio>=1.0.0\",\n  \"pytest-env\",\n  \"ruff>=0.11.13\",\n]\n\n# mlx[cuda] requires a newer version of mlx. the ideal on linux is: default to mlx[cpu] unless[cuda] specified.\n[project.optional-dependencies]\n# cuda = [\n#     \"mlx[cuda]==0.26.3\",\n# ]\n\n###\n# workspace configuration\n###\n\n[tool.uv.workspace]\nmembers = [\"rust/exo_pyo3_bindings\", \"bench\"]\n\n[tool.uv.sources]\nexo_pyo3_bindings = { workspace = true }\nmlx = { git = \"https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git\", branch = \"address-rdma-gpu-locks\", marker = \"sys_platform == 'darwin'\" }\nmlx-lm = { git = \"https://github.com/rltakashige/mlx-lm\", branch = \"leo/eval-left-padding-in-batched-rotation\" }\n# Uncomment to use local mlx/mlx-lm development versions:\n# mlx = { path = \"/Users/Shared/mlx\", editable=true }\n# mlx-lm = { path = \"/Users/Shared/mlx-lm\", editable=true }\n\n[build-system]\nrequires = [\"uv_build>=0.8.9,<0.9.0\"]\nbuild-backend = \"uv_build\"\n\n###\n# type-checker configuration\n###\n\n[tool.basedpyright]\ninclude = [\".venv/lib/mlx\", \".venv/lib/mlx_lm\", \"src\", \"bench\"]\ntypeCheckingMode = \"strict\"\nfailOnWarnings = true\n\nreportAny = \"error\"\nreportUnknownVariableType = \"error\"\nreportUnknownParameterType = \"error\"\nreportMissingParameterType = \"error\"\nreportMissingTypeStubs = \"error\"\nreportInvalidCast = \"error\"\nreportUnnecessaryCast = \"error\"\nreportUnnecessaryTypeIgnoreComment = \"error\"\n\npythonVersion = \"3.13\"\npythonPlatform = \"Darwin\"\n\nexclude = [\n  \"**/.venv\",\n  \"**/venv\",\n  \"**/__pycache__\",\n  \"**/exo_scripts\",\n  \"**/.direnv\",\n  \"**/rust\",\n  \"**/.github\",\n]\nstubPath = \".mlx_typings\"\n\n[[tool.basedpyright.executionEnvironments]]\nroot = \"src\"\n\n###\n# uv configuration\n###\n\n# supported platforms for this project\n[tool.uv]\nrequired-version = \">=0.8.6\"\nprerelease = \"allow\"\nenvironments = [\"sys_platform == 'darwin'\", \"sys_platform == 'linux'\"]\n\n###\n# ruff configuration\n###\n\n[tool.ruff]\nextend-exclude = [\n  \"shared/protobufs/**\",\n  \"*mlx_typings/**\",\n  \"rust/exo_pyo3_bindings/**\",\n  \"bench/vendor/**\",\n]\n\n[tool.ruff.lint]\nextend-select = [\"I\", \"N\", \"B\", \"A\", \"PIE\", \"SIM\"]\n\n[tool.pytest.ini_options]\npythonpath = \".\"\nasyncio_mode = \"auto\"\nmarkers = [\"slow: marks tests as slow (deselected by default)\"]\nenv = [\"EXO_TESTS=1\"]\naddopts = \"-m 
'not slow' --ignore=tests/start_distributed_test.py\"\nfilterwarnings = [\"ignore:builtin type Swig:DeprecationWarning\"]\n"
  },
  {
    "path": "python/parts.nix",
    "content": "{ inputs, ... }:\n{\n  perSystem =\n    { config, self', pkgs, lib, system, ... }:\n    let\n      # Load workspace from uv.lock\n      workspace = inputs.uv2nix.lib.workspace.loadWorkspace {\n        workspaceRoot = inputs.self;\n      };\n\n      # Create overlay from workspace\n      # Use wheels from PyPI for most packages; we override mlx with our pure Nix Metal build\n      overlay = workspace.mkPyprojectOverlay { sourcePreference = \"wheel\"; };\n\n      # Override overlay to inject Nix-built components\n      exoOverlay = final: prev: {\n        # Replace workspace exo_pyo3_bindings with Nix-built wheel.\n        # Preserve passthru so mkVirtualEnv can resolve dependency groups.\n        # Copy .pyi stub + py.typed marker so basedpyright can find the types.\n        exo-pyo3-bindings = pkgs.stdenv.mkDerivation {\n          pname = \"exo-pyo3-bindings\";\n          version = \"0.1.0\";\n          src = self'.packages.exo_pyo3_bindings;\n          # Install from pre-built wheel\n          nativeBuildInputs = [ final.pyprojectWheelHook ];\n          dontStrip = true;\n          passthru = prev.exo-pyo3-bindings.passthru or { };\n          postInstall = ''\n            local siteDir=$out/${final.python.sitePackages}/exo_pyo3_bindings\n            cp ${inputs.self}/rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi $siteDir/\n            touch $siteDir/py.typed\n          '';\n        };\n      };\n\n      python = pkgs.python313;\n\n      # Overlay to provide build systems and custom packages\n      buildSystemsOverlay = final: prev: {\n        # mlx-lm is a git dependency that needs setuptools\n        mlx-lm = prev.mlx-lm.overrideAttrs (old: {\n          nativeBuildInputs = (old.nativeBuildInputs or [ ]) ++ [\n            final.setuptools\n          ];\n        });\n        # rouge-score and sacrebleu don't declare setuptools as a build dependency\n        rouge-score = prev.rouge-score.overrideAttrs (old: {\n          nativeBuildInputs = (old.nativeBuildInputs or [ ]) ++ [\n            final.setuptools\n          ];\n        });\n        sacrebleu = prev.sacrebleu.overrideAttrs (old: {\n          nativeBuildInputs = (old.nativeBuildInputs or [ ]) ++ [\n            final.setuptools\n          ];\n        });\n        sqlitedict = prev.sqlitedict.overrideAttrs (old: {\n          nativeBuildInputs = (old.nativeBuildInputs or [ ]) ++ [\n            final.setuptools\n          ];\n        });\n        word2number = prev.word2number.overrideAttrs (old: {\n          nativeBuildInputs = (old.nativeBuildInputs or [ ]) ++ [\n            final.setuptools\n          ];\n        });\n      } // lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin {\n        # Use our pure Nix-built MLX with Metal support (macOS only)\n        mlx = self'.packages.mlx;\n      };\n\n      # Additional overlay for Linux-specific fixes (type checking env).\n      # Native wheels have shared lib dependencies we don't need at type-check time.\n      linuxOverlay = final: prev:\n        let\n          ignoreMissing = drv: drv.overrideAttrs { autoPatchelfIgnoreMissingDeps = [ \"*\" ]; };\n          nvidiaPackages = lib.filterAttrs (name: _: lib.hasPrefix \"nvidia-\" name) prev;\n        in\n        lib.optionalAttrs pkgs.stdenv.hostPlatform.isLinux (\n          (lib.mapAttrs (_: ignoreMissing) nvidiaPackages) // {\n            mlx = ignoreMissing prev.mlx;\n            mlx-cuda-13 = prev.mlx-cuda-13.overrideAttrs (old: {\n              buildInputs = (old.buildInputs or [ ]) ++ [\n                
final.nvidia-cublas\n                final.nvidia-cuda-nvrtc\n                final.nvidia-cudnn-cu13\n                final.nvidia-nccl-cu13\n              ];\n              preFixup = ''\n                addAutoPatchelfSearchPath ${final.nvidia-cublas}\n                addAutoPatchelfSearchPath ${final.nvidia-cuda-nvrtc}\n                addAutoPatchelfSearchPath ${final.nvidia-cudnn-cu13}\n                addAutoPatchelfSearchPath ${final.nvidia-nccl-cu13}\n              '';\n              autoPatchelfIgnoreMissingDeps = [ \"libcuda.so.1\" ];\n            });\n            torch = ignoreMissing prev.torch;\n            triton = ignoreMissing prev.triton;\n          }\n        );\n\n      pythonSet = (pkgs.callPackage inputs.pyproject-nix.build.packages {\n        inherit python;\n      }).overrideScope (\n        lib.composeManyExtensions [\n          inputs.pyproject-build-systems.overlays.default\n          overlay\n          exoOverlay\n          buildSystemsOverlay\n          linuxOverlay\n        ]\n      );\n      # mlx-cpu and mlx-cuda-13 both ship mlx/ site-packages files; keep first.\n      # mlx-cpu/mlx-cuda-13 and nvidia-cudnn-cu12/cu13 ship overlapping files.\n      venvCollisionPaths = lib.optionals pkgs.stdenv.hostPlatform.isLinux [\n        \"lib/python3.13/site-packages/mlx*\"\n        \"lib/python3.13/site-packages/nvidia*\"\n      ];\n\n      # Exclude bench deps from main env (bench has its own benchVenv)\n      exoDeps = removeAttrs workspace.deps.default [ \"exo-bench\" ];\n\n      exoVenv = (pythonSet.mkVirtualEnv \"exo-env\" exoDeps).overrideAttrs {\n        venvIgnoreCollisions = venvCollisionPaths;\n      };\n\n      # Virtual environment with dev dependencies for testing\n      testVenv = (pythonSet.mkVirtualEnv \"exo-test-env\" (\n        exoDeps // {\n          exo = [ \"dev\" ]; # Include pytest, pytest-asyncio, pytest-env\n        }\n      )).overrideAttrs {\n        venvIgnoreCollisions = venvCollisionPaths;\n      };\n\n      mkPythonScript = name: path: pkgs.writeShellApplication {\n        inherit name;\n        runtimeInputs = [ exoVenv ];\n        runtimeEnv = {\n          EXO_DASHBOARD_DIR = self'.packages.dashboard;\n          EXO_RESOURCES_DIR = inputs.self + /resources;\n        };\n        text = ''exec python ${path} \"$@\"'';\n      };\n\n      benchVenv = pythonSet.mkVirtualEnv \"exo-bench-env\" {\n        exo-bench = [ ];\n      };\n\n      mkBenchScript = name: path: pkgs.writeShellApplication {\n        inherit name;\n        runtimeInputs = [ benchVenv ];\n        text = ''exec python ${path} \"$@\"'';\n      };\n\n      mkSimplePythonScript = name: path: pkgs.writeShellApplication {\n        inherit name;\n        runtimeInputs = [ pkgs.python313 ];\n        text = ''exec python ${path} \"$@\"'';\n      };\n\n      exoPackage = pkgs.runCommand \"exo\"\n        {\n          nativeBuildInputs = [ pkgs.makeWrapper ];\n        }\n        ''\n          mkdir -p $out/bin\n\n          # Create wrapper script\n          makeWrapper ${exoVenv}/bin/exo $out/bin/exo \\\n            --set EXO_DASHBOARD_DIR ${self'.packages.dashboard} \\\n            --set EXO_RESOURCES_DIR ${inputs.self + /resources} \\\n            ${lib.optionalString pkgs.stdenv.hostPlatform.isDarwin \"--prefix PATH : ${pkgs.macmon}/bin\"}\n        '';\n    in\n    {\n      # Python package only available on macOS (requires MLX/Metal)\n      packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin\n        {\n          exo = exoPackage;\n          # Test environment for 
running pytest outside of Nix sandbox (needs GPU access)\n          exo-test-env = testVenv;\n        } // {\n        exo-bench = mkBenchScript \"exo-bench\" (inputs.self + /bench/exo_bench.py);\n        exo-eval = mkBenchScript \"exo-eval\" (inputs.self + /bench/exo_eval.py);\n        exo-eval-tool-calls = mkBenchScript \"exo-eval-tool-calls\" (inputs.self + /bench/eval_tool_calls.py);\n        exo-get-all-models-on-cluster = mkSimplePythonScript \"exo-get-all-models-on-cluster\" (inputs.self + /tests/get_all_models_on_cluster.py);\n      };\n\n      checks = {\n        # Ruff linting (works on all platforms)\n        lint = pkgs.runCommand \"ruff-lint\" { } ''\n          export RUFF_CACHE_DIR=\"$TMPDIR/ruff-cache\"\n          ${pkgs.ruff}/bin/ruff check ${inputs.self}\n          touch $out\n        '';\n\n        # Hermetic basedpyright type checking\n        typecheck = pkgs.runCommand \"typecheck\"\n          {\n            nativeBuildInputs = [\n              testVenv\n              pkgs.basedpyright\n            ];\n          }\n          ''\n            cd ${inputs.self}\n            export HOME=$TMPDIR\n            basedpyright --pythonpath ${testVenv}/bin/python\n            touch $out\n          '';\n      };\n    };\n}\n"
  },
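The flake's wrapper deliberately injects `EXO_DASHBOARD_DIR` and `EXO_RESOURCES_DIR` (pointing at the packaged dashboard and the `resources/` tree of model cards below) rather than baking store paths into the Python code. A hypothetical consumer would resolve them as sketched here; the helper names are illustrative, since exo's real lookup code is not part of this section:

```python
import os
from pathlib import Path


def resources_directory() -> Path:
    # Injected by runtimeEnv / makeWrapper in the flake above; a KeyError
    # here means the process was started outside the Nix wrapper.
    return Path(os.environ["EXO_RESOURCES_DIR"])


def dashboard_directory() -> Path:
    return Path(os.environ["EXO_DASHBOARD_DIR"])
```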
  {
    "path": "resources/image_model_cards/exolabs--FLUX.1-Kontext-dev-4bit.toml",
    "content": "model_id = \"exolabs/FLUX.1-Kontext-dev-4bit\"\nn_layers = 57\nhidden_size = 1\nsupports_tensor = false\ntasks = [\"ImageToImage\"]\nfamily = \"flux\"\nquantization = \"4bit\"\nbase_model = \"FLUX.1 Kontext\"\ncapabilities = [\"image_edit\"]\n\n[storage_size]\nin_bytes = 15475325472\n\n[[components]]\ncomponent_name = \"text_encoder\"\ncomponent_path = \"text_encoder/\"\nn_layers = 12\ncan_shard = false\n\n[components.storage_size]\nin_bytes = 0\n\n[[components]]\ncomponent_name = \"text_encoder_2\"\ncomponent_path = \"text_encoder_2/\"\nn_layers = 24\ncan_shard = false\nsafetensors_index_filename = \"model.safetensors.index.json\"\n\n[components.storage_size]\nin_bytes = 9524621312\n\n[[components]]\ncomponent_name = \"transformer\"\ncomponent_path = \"transformer/\"\nn_layers = 57\ncan_shard = true\nsafetensors_index_filename = \"diffusion_pytorch_model.safetensors.index.json\"\n\n[components.storage_size]\nin_bytes = 5950704160\n\n[[components]]\ncomponent_name = \"vae\"\ncomponent_path = \"vae/\"\ncan_shard = false\n\n[components.storage_size]\nin_bytes = 0\n"
  },
  {
    "path": "resources/image_model_cards/exolabs--FLUX.1-Kontext-dev-8bit.toml",
    "content": "model_id = \"exolabs/FLUX.1-Kontext-dev-8bit\"\nn_layers = 57\nhidden_size = 1\nsupports_tensor = false\ntasks = [\"ImageToImage\"]\nfamily = \"flux\"\nquantization = \"8bit\"\nbase_model = \"FLUX.1 Kontext\"\ncapabilities = [\"image_edit\"]\n\n[storage_size]\nin_bytes = 21426029632\n\n[[components]]\ncomponent_name = \"text_encoder\"\ncomponent_path = \"text_encoder/\"\nn_layers = 12\ncan_shard = false\n\n[components.storage_size]\nin_bytes = 0\n\n[[components]]\ncomponent_name = \"text_encoder_2\"\ncomponent_path = \"text_encoder_2/\"\nn_layers = 24\ncan_shard = false\nsafetensors_index_filename = \"model.safetensors.index.json\"\n\n[components.storage_size]\nin_bytes = 9524621312\n\n[[components]]\ncomponent_name = \"transformer\"\ncomponent_path = \"transformer/\"\nn_layers = 57\ncan_shard = true\nsafetensors_index_filename = \"diffusion_pytorch_model.safetensors.index.json\"\n\n[components.storage_size]\nin_bytes = 11901408320\n\n[[components]]\ncomponent_name = \"vae\"\ncomponent_path = \"vae/\"\ncan_shard = false\n\n[components.storage_size]\nin_bytes = 0\n"
  },
  {
    "path": "resources/image_model_cards/exolabs--FLUX.1-Kontext-dev.toml",
    "content": "model_id = \"exolabs/FLUX.1-Kontext-dev\"\nn_layers = 57\nhidden_size = 1\nsupports_tensor = false\ntasks = [\"ImageToImage\"]\nfamily = \"flux\"\nquantization = \"\"\nbase_model = \"FLUX.1 Kontext\"\ncapabilities = [\"image_edit\"]\n\n[storage_size]\nin_bytes = 33327437952\n\n[[components]]\ncomponent_name = \"text_encoder\"\ncomponent_path = \"text_encoder/\"\nn_layers = 12\ncan_shard = false\n\n[components.storage_size]\nin_bytes = 0\n\n[[components]]\ncomponent_name = \"text_encoder_2\"\ncomponent_path = \"text_encoder_2/\"\nn_layers = 24\ncan_shard = false\nsafetensors_index_filename = \"model.safetensors.index.json\"\n\n[components.storage_size]\nin_bytes = 9524621312\n\n[[components]]\ncomponent_name = \"transformer\"\ncomponent_path = \"transformer/\"\nn_layers = 57\ncan_shard = true\nsafetensors_index_filename = \"diffusion_pytorch_model.safetensors.index.json\"\n\n[components.storage_size]\nin_bytes = 23802816640\n\n[[components]]\ncomponent_name = \"vae\"\ncomponent_path = \"vae/\"\ncan_shard = false\n\n[components.storage_size]\nin_bytes = 0\n"
  },
  {
    "path": "resources/image_model_cards/exolabs--FLUX.1-Krea-dev-4bit.toml",
    "content": "model_id = \"exolabs/FLUX.1-Krea-dev-4bit\"\nn_layers = 57\nhidden_size = 1\nsupports_tensor = false\ntasks = [\"TextToImage\"]\nfamily = \"flux\"\nquantization = \"4bit\"\nbase_model = \"FLUX.1 Krea\"\ncapabilities = [\"image_gen\"]\n\n[storage_size]\nin_bytes = 15475325472\n\n[[components]]\ncomponent_name = \"text_encoder\"\ncomponent_path = \"text_encoder/\"\nn_layers = 12\ncan_shard = false\n\n[components.storage_size]\nin_bytes = 0\n\n[[components]]\ncomponent_name = \"text_encoder_2\"\ncomponent_path = \"text_encoder_2/\"\nn_layers = 24\ncan_shard = false\nsafetensors_index_filename = \"model.safetensors.index.json\"\n\n[components.storage_size]\nin_bytes = 9524621312\n\n[[components]]\ncomponent_name = \"transformer\"\ncomponent_path = \"transformer/\"\nn_layers = 57\ncan_shard = true\nsafetensors_index_filename = \"diffusion_pytorch_model.safetensors.index.json\"\n\n[components.storage_size]\nin_bytes = 5950704160\n\n[[components]]\ncomponent_name = \"vae\"\ncomponent_path = \"vae/\"\ncan_shard = false\n\n[components.storage_size]\nin_bytes = 0\n"
  },
  {
    "path": "resources/image_model_cards/exolabs--FLUX.1-Krea-dev-8bit.toml",
    "content": "model_id = \"exolabs/FLUX.1-Krea-dev-8bit\"\nn_layers = 57\nhidden_size = 1\nsupports_tensor = false\ntasks = [\"TextToImage\"]\nfamily = \"flux\"\nquantization = \"8bit\"\nbase_model = \"FLUX.1 Krea\"\ncapabilities = [\"image_gen\"]\n\n[storage_size]\nin_bytes = 21426029632\n\n[[components]]\ncomponent_name = \"text_encoder\"\ncomponent_path = \"text_encoder/\"\nn_layers = 12\ncan_shard = false\n\n[components.storage_size]\nin_bytes = 0\n\n[[components]]\ncomponent_name = \"text_encoder_2\"\ncomponent_path = \"text_encoder_2/\"\nn_layers = 24\ncan_shard = false\nsafetensors_index_filename = \"model.safetensors.index.json\"\n\n[components.storage_size]\nin_bytes = 9524621312\n\n[[components]]\ncomponent_name = \"transformer\"\ncomponent_path = \"transformer/\"\nn_layers = 57\ncan_shard = true\nsafetensors_index_filename = \"diffusion_pytorch_model.safetensors.index.json\"\n\n[components.storage_size]\nin_bytes = 11901408320\n\n[[components]]\ncomponent_name = \"vae\"\ncomponent_path = \"vae/\"\ncan_shard = false\n\n[components.storage_size]\nin_bytes = 0\n"
  },
  {
    "path": "resources/image_model_cards/exolabs--FLUX.1-Krea-dev.toml",
    "content": "model_id = \"exolabs/FLUX.1-Krea-dev\"\nn_layers = 57\nhidden_size = 1\nsupports_tensor = false\ntasks = [\"TextToImage\"]\nfamily = \"flux\"\nquantization = \"\"\nbase_model = \"FLUX.1 Krea\"\ncapabilities = [\"image_gen\"]\n\n[storage_size]\nin_bytes = 33327437952\n\n[[components]]\ncomponent_name = \"text_encoder\"\ncomponent_path = \"text_encoder/\"\nn_layers = 12\ncan_shard = false\n\n[components.storage_size]\nin_bytes = 0\n\n[[components]]\ncomponent_name = \"text_encoder_2\"\ncomponent_path = \"text_encoder_2/\"\nn_layers = 24\ncan_shard = false\nsafetensors_index_filename = \"model.safetensors.index.json\"\n\n[components.storage_size]\nin_bytes = 9524621312\n\n[[components]]\ncomponent_name = \"transformer\"\ncomponent_path = \"transformer/\"\nn_layers = 57\ncan_shard = true\nsafetensors_index_filename = \"diffusion_pytorch_model.safetensors.index.json\"\n\n[components.storage_size]\nin_bytes = 23802816640\n\n[[components]]\ncomponent_name = \"vae\"\ncomponent_path = \"vae/\"\ncan_shard = false\n\n[components.storage_size]\nin_bytes = 0\n"
  },
  {
    "path": "resources/image_model_cards/exolabs--FLUX.1-dev-4bit.toml",
    "content": "model_id = \"exolabs/FLUX.1-dev-4bit\"\nn_layers = 57\nhidden_size = 1\nsupports_tensor = false\ntasks = [\"TextToImage\"]\nfamily = \"flux\"\nquantization = \"4bit\"\nbase_model = \"FLUX.1 Dev\"\ncapabilities = [\"image_gen\"]\n\n[storage_size]\nin_bytes = 15475325472\n\n[[components]]\ncomponent_name = \"text_encoder\"\ncomponent_path = \"text_encoder/\"\nn_layers = 12\ncan_shard = false\n\n[components.storage_size]\nin_bytes = 0\n\n[[components]]\ncomponent_name = \"text_encoder_2\"\ncomponent_path = \"text_encoder_2/\"\nn_layers = 24\ncan_shard = false\nsafetensors_index_filename = \"model.safetensors.index.json\"\n\n[components.storage_size]\nin_bytes = 9524621312\n\n[[components]]\ncomponent_name = \"transformer\"\ncomponent_path = \"transformer/\"\nn_layers = 57\ncan_shard = true\nsafetensors_index_filename = \"diffusion_pytorch_model.safetensors.index.json\"\n\n[components.storage_size]\nin_bytes = 5950704160\n\n[[components]]\ncomponent_name = \"vae\"\ncomponent_path = \"vae/\"\ncan_shard = false\n\n[components.storage_size]\nin_bytes = 0\n"
  },
  {
    "path": "resources/image_model_cards/exolabs--FLUX.1-dev-8bit.toml",
    "content": "model_id = \"exolabs/FLUX.1-dev-8bit\"\nn_layers = 57\nhidden_size = 1\nsupports_tensor = false\ntasks = [\"TextToImage\"]\nfamily = \"flux\"\nquantization = \"8bit\"\nbase_model = \"FLUX.1 Dev\"\ncapabilities = [\"image_gen\"]\n\n[storage_size]\nin_bytes = 21426029632\n\n[[components]]\ncomponent_name = \"text_encoder\"\ncomponent_path = \"text_encoder/\"\nn_layers = 12\ncan_shard = false\n\n[components.storage_size]\nin_bytes = 0\n\n[[components]]\ncomponent_name = \"text_encoder_2\"\ncomponent_path = \"text_encoder_2/\"\nn_layers = 24\ncan_shard = false\nsafetensors_index_filename = \"model.safetensors.index.json\"\n\n[components.storage_size]\nin_bytes = 9524621312\n\n[[components]]\ncomponent_name = \"transformer\"\ncomponent_path = \"transformer/\"\nn_layers = 57\ncan_shard = true\nsafetensors_index_filename = \"diffusion_pytorch_model.safetensors.index.json\"\n\n[components.storage_size]\nin_bytes = 11901408320\n\n[[components]]\ncomponent_name = \"vae\"\ncomponent_path = \"vae/\"\ncan_shard = false\n\n[components.storage_size]\nin_bytes = 0\n"
  },
  {
    "path": "resources/image_model_cards/exolabs--FLUX.1-dev.toml",
    "content": "model_id = \"exolabs/FLUX.1-dev\"\nn_layers = 57\nhidden_size = 1\nsupports_tensor = false\ntasks = [\"TextToImage\"]\nfamily = \"flux\"\nquantization = \"\"\nbase_model = \"FLUX.1 Dev\"\ncapabilities = [\"image_gen\"]\n\n[storage_size]\nin_bytes = 33327437952\n\n[[components]]\ncomponent_name = \"text_encoder\"\ncomponent_path = \"text_encoder/\"\nn_layers = 12\ncan_shard = false\n\n[components.storage_size]\nin_bytes = 0\n\n[[components]]\ncomponent_name = \"text_encoder_2\"\ncomponent_path = \"text_encoder_2/\"\nn_layers = 24\ncan_shard = false\nsafetensors_index_filename = \"model.safetensors.index.json\"\n\n[components.storage_size]\nin_bytes = 9524621312\n\n[[components]]\ncomponent_name = \"transformer\"\ncomponent_path = \"transformer/\"\nn_layers = 57\ncan_shard = true\nsafetensors_index_filename = \"diffusion_pytorch_model.safetensors.index.json\"\n\n[components.storage_size]\nin_bytes = 23802816640\n\n[[components]]\ncomponent_name = \"vae\"\ncomponent_path = \"vae/\"\ncan_shard = false\n\n[components.storage_size]\nin_bytes = 0\n"
  },
  {
    "path": "resources/image_model_cards/exolabs--FLUX.1-schnell-4bit.toml",
    "content": "model_id = \"exolabs/FLUX.1-schnell-4bit\"\nn_layers = 57\nhidden_size = 1\nsupports_tensor = false\ntasks = [\"TextToImage\"]\nfamily = \"flux\"\nquantization = \"4bit\"\nbase_model = \"FLUX.1 Schnell\"\ncapabilities = [\"image_gen\"]\n\n[storage_size]\nin_bytes = 15470210592\n\n[[components]]\ncomponent_name = \"text_encoder\"\ncomponent_path = \"text_encoder/\"\nn_layers = 12\ncan_shard = false\n\n[components.storage_size]\nin_bytes = 0\n\n[[components]]\ncomponent_name = \"text_encoder_2\"\ncomponent_path = \"text_encoder_2/\"\nn_layers = 24\ncan_shard = false\nsafetensors_index_filename = \"model.safetensors.index.json\"\n\n[components.storage_size]\nin_bytes = 9524621312\n\n[[components]]\ncomponent_name = \"transformer\"\ncomponent_path = \"transformer/\"\nn_layers = 57\ncan_shard = true\nsafetensors_index_filename = \"diffusion_pytorch_model.safetensors.index.json\"\n\n[components.storage_size]\nin_bytes = 5945589280\n\n[[components]]\ncomponent_name = \"vae\"\ncomponent_path = \"vae/\"\ncan_shard = false\n\n[components.storage_size]\nin_bytes = 0\n"
  },
  {
    "path": "resources/image_model_cards/exolabs--FLUX.1-schnell-8bit.toml",
    "content": "model_id = \"exolabs/FLUX.1-schnell-8bit\"\nn_layers = 57\nhidden_size = 1\nsupports_tensor = false\ntasks = [\"TextToImage\"]\nfamily = \"flux\"\nquantization = \"8bit\"\nbase_model = \"FLUX.1 Schnell\"\ncapabilities = [\"image_gen\"]\n\n[storage_size]\nin_bytes = 21415799872\n\n[[components]]\ncomponent_name = \"text_encoder\"\ncomponent_path = \"text_encoder/\"\nn_layers = 12\ncan_shard = false\n\n[components.storage_size]\nin_bytes = 0\n\n[[components]]\ncomponent_name = \"text_encoder_2\"\ncomponent_path = \"text_encoder_2/\"\nn_layers = 24\ncan_shard = false\nsafetensors_index_filename = \"model.safetensors.index.json\"\n\n[components.storage_size]\nin_bytes = 9524621312\n\n[[components]]\ncomponent_name = \"transformer\"\ncomponent_path = \"transformer/\"\nn_layers = 57\ncan_shard = true\nsafetensors_index_filename = \"diffusion_pytorch_model.safetensors.index.json\"\n\n[components.storage_size]\nin_bytes = 11891178560\n\n[[components]]\ncomponent_name = \"vae\"\ncomponent_path = \"vae/\"\ncan_shard = false\n\n[components.storage_size]\nin_bytes = 0\n"
  },
  {
    "path": "resources/image_model_cards/exolabs--FLUX.1-schnell.toml",
    "content": "model_id = \"exolabs/FLUX.1-schnell\"\nn_layers = 57\nhidden_size = 1\nsupports_tensor = false\ntasks = [\"TextToImage\"]\nfamily = \"flux\"\nquantization = \"\"\nbase_model = \"FLUX.1 Schnell\"\ncapabilities = [\"image_gen\"]\n\n[storage_size]\nin_bytes = 33306978432\n\n[[components]]\ncomponent_name = \"text_encoder\"\ncomponent_path = \"text_encoder/\"\nn_layers = 12\ncan_shard = false\n\n[components.storage_size]\nin_bytes = 0\n\n[[components]]\ncomponent_name = \"text_encoder_2\"\ncomponent_path = \"text_encoder_2/\"\nn_layers = 24\ncan_shard = false\nsafetensors_index_filename = \"model.safetensors.index.json\"\n\n[components.storage_size]\nin_bytes = 9524621312\n\n[[components]]\ncomponent_name = \"transformer\"\ncomponent_path = \"transformer/\"\nn_layers = 57\ncan_shard = true\nsafetensors_index_filename = \"diffusion_pytorch_model.safetensors.index.json\"\n\n[components.storage_size]\nin_bytes = 23782357120\n\n[[components]]\ncomponent_name = \"vae\"\ncomponent_path = \"vae/\"\ncan_shard = false\n\n[components.storage_size]\nin_bytes = 0\n"
  },
  {
    "path": "resources/image_model_cards/exolabs--Qwen-Image-4bit.toml",
    "content": "model_id = \"exolabs/Qwen-Image-4bit\"\nn_layers = 60\nhidden_size = 1\nsupports_tensor = false\ntasks = [\"TextToImage\"]\nuses_cfg = true\nfamily = \"qwen-image\"\nquantization = \"4bit\"\nbase_model = \"Qwen Image\"\ncapabilities = [\"image_gen\"]\n\n[storage_size]\nin_bytes = 26799533856\n\n[[components]]\ncomponent_name = \"text_encoder\"\ncomponent_path = \"text_encoder/\"\nn_layers = 12\ncan_shard = false\n\n[components.storage_size]\nin_bytes = 16584333312\n\n[[components]]\ncomponent_name = \"transformer\"\ncomponent_path = \"transformer/\"\nn_layers = 60\ncan_shard = true\nsafetensors_index_filename = \"diffusion_pytorch_model.safetensors.index.json\"\n\n[components.storage_size]\nin_bytes = 10215200544\n\n[[components]]\ncomponent_name = \"vae\"\ncomponent_path = \"vae/\"\ncan_shard = false\n\n[components.storage_size]\nin_bytes = 0\n"
  },
  {
    "path": "resources/image_model_cards/exolabs--Qwen-Image-8bit.toml",
    "content": "model_id = \"exolabs/Qwen-Image-8bit\"\nn_layers = 60\nhidden_size = 1\nsupports_tensor = false\ntasks = [\"TextToImage\"]\nuses_cfg = true\nfamily = \"qwen-image\"\nquantization = \"8bit\"\nbase_model = \"Qwen Image\"\ncapabilities = [\"image_gen\"]\n\n[storage_size]\nin_bytes = 37014734400\n\n[[components]]\ncomponent_name = \"text_encoder\"\ncomponent_path = \"text_encoder/\"\nn_layers = 12\ncan_shard = false\n\n[components.storage_size]\nin_bytes = 16584333312\n\n[[components]]\ncomponent_name = \"transformer\"\ncomponent_path = \"transformer/\"\nn_layers = 60\ncan_shard = true\nsafetensors_index_filename = \"diffusion_pytorch_model.safetensors.index.json\"\n\n[components.storage_size]\nin_bytes = 20430401088\n\n[[components]]\ncomponent_name = \"vae\"\ncomponent_path = \"vae/\"\ncan_shard = false\n\n[components.storage_size]\nin_bytes = 0\n"
  },
  {
    "path": "resources/image_model_cards/exolabs--Qwen-Image-Edit-2509-4bit.toml",
    "content": "model_id = \"exolabs/Qwen-Image-Edit-2509-4bit\"\nn_layers = 60\nhidden_size = 1\nsupports_tensor = false\ntasks = [\"ImageToImage\"]\nuses_cfg = true\nfamily = \"qwen-image\"\nquantization = \"4bit\"\nbase_model = \"Qwen Image Edit\"\ncapabilities = [\"image_edit\"]\n\n[storage_size]\nin_bytes = 26799533856\n\n[[components]]\ncomponent_name = \"text_encoder\"\ncomponent_path = \"text_encoder/\"\nn_layers = 12\ncan_shard = false\n\n[components.storage_size]\nin_bytes = 16584333312\n\n[[components]]\ncomponent_name = \"transformer\"\ncomponent_path = \"transformer/\"\nn_layers = 60\ncan_shard = true\nsafetensors_index_filename = \"diffusion_pytorch_model.safetensors.index.json\"\n\n[components.storage_size]\nin_bytes = 10215200544\n\n[[components]]\ncomponent_name = \"vae\"\ncomponent_path = \"vae/\"\ncan_shard = false\n\n[components.storage_size]\nin_bytes = 0\n"
  },
  {
    "path": "resources/image_model_cards/exolabs--Qwen-Image-Edit-2509-8bit.toml",
    "content": "model_id = \"exolabs/Qwen-Image-Edit-2509-8bit\"\nn_layers = 60\nhidden_size = 1\nsupports_tensor = false\ntasks = [\"ImageToImage\"]\nuses_cfg = true\nfamily = \"qwen-image\"\nquantization = \"8bit\"\nbase_model = \"Qwen Image Edit\"\ncapabilities = [\"image_edit\"]\n\n[storage_size]\nin_bytes = 37014734400\n\n[[components]]\ncomponent_name = \"text_encoder\"\ncomponent_path = \"text_encoder/\"\nn_layers = 12\ncan_shard = false\n\n[components.storage_size]\nin_bytes = 16584333312\n\n[[components]]\ncomponent_name = \"transformer\"\ncomponent_path = \"transformer/\"\nn_layers = 60\ncan_shard = true\nsafetensors_index_filename = \"diffusion_pytorch_model.safetensors.index.json\"\n\n[components.storage_size]\nin_bytes = 20430401088\n\n[[components]]\ncomponent_name = \"vae\"\ncomponent_path = \"vae/\"\ncan_shard = false\n\n[components.storage_size]\nin_bytes = 0\n"
  },
  {
    "path": "resources/image_model_cards/exolabs--Qwen-Image-Edit-2509.toml",
    "content": "model_id = \"exolabs/Qwen-Image-Edit-2509\"\nn_layers = 60\nhidden_size = 1\nsupports_tensor = false\ntasks = [\"ImageToImage\"]\nuses_cfg = true\nfamily = \"qwen-image\"\nquantization = \"\"\nbase_model = \"Qwen Image Edit\"\ncapabilities = [\"image_edit\"]\n\n[storage_size]\nin_bytes = 57445135488\n\n[[components]]\ncomponent_name = \"text_encoder\"\ncomponent_path = \"text_encoder/\"\nn_layers = 12\ncan_shard = false\n\n[components.storage_size]\nin_bytes = 16584333312\n\n[[components]]\ncomponent_name = \"transformer\"\ncomponent_path = \"transformer/\"\nn_layers = 60\ncan_shard = true\nsafetensors_index_filename = \"diffusion_pytorch_model.safetensors.index.json\"\n\n[components.storage_size]\nin_bytes = 40860802176\n\n[[components]]\ncomponent_name = \"vae\"\ncomponent_path = \"vae/\"\ncan_shard = false\n\n[components.storage_size]\nin_bytes = 0\n"
  },
  {
    "path": "resources/image_model_cards/exolabs--Qwen-Image.toml",
    "content": "model_id = \"exolabs/Qwen-Image\"\nn_layers = 60\nhidden_size = 1\nsupports_tensor = false\ntasks = [\"TextToImage\"]\nuses_cfg = true\nfamily = \"qwen-image\"\nquantization = \"\"\nbase_model = \"Qwen Image\"\ncapabilities = [\"image_gen\"]\n\n[storage_size]\nin_bytes = 57445135488\n\n[[components]]\ncomponent_name = \"text_encoder\"\ncomponent_path = \"text_encoder/\"\nn_layers = 12\ncan_shard = false\n\n[components.storage_size]\nin_bytes = 16584333312\n\n[[components]]\ncomponent_name = \"transformer\"\ncomponent_path = \"transformer/\"\nn_layers = 60\ncan_shard = true\nsafetensors_index_filename = \"diffusion_pytorch_model.safetensors.index.json\"\n\n[components.storage_size]\nin_bytes = 40860802176\n\n[[components]]\ncomponent_name = \"vae\"\ncomponent_path = \"vae/\"\ncan_shard = false\n\n[components.storage_size]\nin_bytes = 0\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--DeepSeek-V3.1-4bit.toml",
    "content": "model_id = \"mlx-community/DeepSeek-V3.1-4bit\"\nn_layers = 61\nhidden_size = 7168\nnum_key_value_heads = 128\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"deepseek\"\nquantization = \"4bit\"\nbase_model = \"DeepSeek V3.1\"\ncapabilities = [\"text\", \"thinking\", \"thinking_toggle\"]\n\n[storage_size]\nin_bytes = 405874409472\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--DeepSeek-V3.1-8bit.toml",
    "content": "model_id = \"mlx-community/DeepSeek-V3.1-8bit\"\nn_layers = 61\nhidden_size = 7168\nnum_key_value_heads = 128\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"deepseek\"\nquantization = \"8bit\"\nbase_model = \"DeepSeek V3.1\"\ncapabilities = [\"text\", \"thinking\", \"thinking_toggle\"]\n\n[storage_size]\nin_bytes = 765577920512\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--GLM-4.5-Air-8bit.toml",
    "content": "model_id = \"mlx-community/GLM-4.5-Air-8bit\"\nn_layers = 46\nhidden_size = 4096\nnum_key_value_heads = 8\nsupports_tensor = false\ntasks = [\"TextGeneration\"]\nfamily = \"glm\"\nquantization = \"8bit\"\nbase_model = \"GLM 4.5 Air\"\ncapabilities = [\"text\", \"thinking\", \"thinking_toggle\"]\n\n[storage_size]\nin_bytes = 122406567936\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--GLM-4.5-Air-bf16.toml",
    "content": "model_id = \"mlx-community/GLM-4.5-Air-bf16\"\nn_layers = 46\nhidden_size = 4096\nnum_key_value_heads = 8\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"glm\"\nquantization = \"bf16\"\nbase_model = \"GLM 4.5 Air\"\ncapabilities = [\"text\", \"thinking\", \"thinking_toggle\"]\n\n[storage_size]\nin_bytes = 229780750336\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--GLM-4.7-4bit.toml",
    "content": "model_id = \"mlx-community/GLM-4.7-4bit\"\nn_layers = 91\nhidden_size = 5120\nnum_key_value_heads = 8\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"glm\"\nquantization = \"4bit\"\nbase_model = \"GLM 4.7\"\ncapabilities = [\"text\", \"thinking\", \"thinking_toggle\"]\n\n[storage_size]\nin_bytes = 198556925568\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--GLM-4.7-6bit.toml",
    "content": "model_id = \"mlx-community/GLM-4.7-6bit\"\nn_layers = 91\nhidden_size = 5120\nnum_key_value_heads = 8\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"glm\"\nquantization = \"6bit\"\nbase_model = \"GLM 4.7\"\ncapabilities = [\"text\", \"thinking\", \"thinking_toggle\"]\n\n[storage_size]\nin_bytes = 286737579648\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--GLM-4.7-8bit-gs32.toml",
    "content": "model_id = \"mlx-community/GLM-4.7-8bit-gs32\"\nn_layers = 91\nhidden_size = 5120\nnum_key_value_heads = 8\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"glm\"\nquantization = \"8bit\"\nbase_model = \"GLM 4.7\"\ncapabilities = [\"text\", \"thinking\", \"thinking_toggle\"]\n\n[storage_size]\nin_bytes = 396963397248\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--GLM-4.7-Flash-4bit.toml",
    "content": "model_id = \"mlx-community/GLM-4.7-Flash-4bit\"\nn_layers = 47\nhidden_size = 2048\nnum_key_value_heads = 20\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"glm\"\nquantization = \"4bit\"\nbase_model = \"GLM 4.7 Flash\"\ncapabilities = [\"text\", \"thinking\", \"thinking_toggle\"]\n\n[storage_size]\nin_bytes = 19327352832\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--GLM-4.7-Flash-5bit.toml",
    "content": "model_id = \"mlx-community/GLM-4.7-Flash-5bit\"\nn_layers = 47\nhidden_size = 2048\nnum_key_value_heads = 20\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"glm\"\nquantization = \"5bit\"\nbase_model = \"GLM 4.7 Flash\"\ncapabilities = [\"text\", \"thinking\", \"thinking_toggle\"]\n\n[storage_size]\nin_bytes = 22548578304\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--GLM-4.7-Flash-6bit.toml",
    "content": "model_id = \"mlx-community/GLM-4.7-Flash-6bit\"\nn_layers = 47\nhidden_size = 2048\nnum_key_value_heads = 20\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"glm\"\nquantization = \"6bit\"\nbase_model = \"GLM 4.7 Flash\"\ncapabilities = [\"text\", \"thinking\", \"thinking_toggle\"]\n\n[storage_size]\nin_bytes = 26843545600\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--GLM-4.7-Flash-8bit.toml",
    "content": "model_id = \"mlx-community/GLM-4.7-Flash-8bit\"\nn_layers = 47\nhidden_size = 2048\nnum_key_value_heads = 20\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"glm\"\nquantization = \"8bit\"\nbase_model = \"GLM 4.7 Flash\"\ncapabilities = [\"text\", \"thinking\", \"thinking_toggle\"]\n\n[storage_size]\nin_bytes = 34359738368\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--GLM-5-8bit.toml",
    "content": "model_id = \"mlx-community/GLM-5-8bit-MXFP8\"\nn_layers = 78\nhidden_size = 6144\nnum_key_value_heads = 64\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"glm\"\nquantization = \"8bit\"\nbase_model = \"GLM-5\"\ncapabilities = [\"text\", \"thinking\"]\n\n[storage_size]\nin_bytes = 790517400864\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--GLM-5-MXFP4-Q8.toml",
    "content": "model_id = \"mlx-community/GLM-5-MXFP4-Q8\"\nn_layers = 78\nhidden_size = 6144\nnum_key_value_heads = 64\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"glm\"\nquantization = \"MXFP4-Q8\"\nbase_model = \"GLM-5\"\ncapabilities = [\"text\", \"thinking\"]\n\n[storage_size]\nin_bytes = 405478939008\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--GLM-5-bf16.toml",
    "content": "model_id = \"mlx-community/GLM-5\"\nn_layers = 78\nhidden_size = 6144\nnum_key_value_heads = 64\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"glm\"\nquantization = \"bf16\"\nbase_model = \"GLM-5\"\ncapabilities = [\"text\", \"thinking\"]\n\n[storage_size]\nin_bytes = 1487822475264\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Kimi-K2-Instruct-4bit.toml",
    "content": "model_id = \"mlx-community/Kimi-K2-Instruct-4bit\"\nn_layers = 61\nhidden_size = 7168\nnum_key_value_heads = 64\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"kimi\"\nquantization = \"4bit\"\nbase_model = \"Kimi K2\"\ncapabilities = [\"text\"]\n\n[storage_size]\nin_bytes = 620622774272\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Kimi-K2-Thinking.toml",
    "content": "model_id = \"mlx-community/Kimi-K2-Thinking\"\nn_layers = 61\nhidden_size = 7168\nnum_key_value_heads = 64\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"kimi\"\nquantization = \"\"\nbase_model = \"Kimi K2\"\ncapabilities = [\"text\", \"thinking\", \"thinking_toggle\"]\n\n[storage_size]\nin_bytes = 706522120192\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Kimi-K2.5.toml",
    "content": "model_id = \"mlx-community/Kimi-K2.5\"\nn_layers = 61\nhidden_size = 7168\nnum_key_value_heads = 64\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"kimi\"\nquantization = \"\"\nbase_model = \"Kimi K2.5\"\ncapabilities = [\"text\", \"thinking\", \"thinking_toggle\"]\n\n[storage_size]\nin_bytes = 662498705408\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Llama-3.1-Nemotron-70B-Instruct-HF-4bit.toml",
    "content": "model_id = \"mlx-community/Llama-3.1-Nemotron-70B-Instruct-HF-4bit\"\nn_layers = 80\nhidden_size = 8192\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"llama\"\nquantization = \"4bit\"\nbase_model = \"NVIDIA Llama-3.1-Nemotron-70B-Instruct\"\ncapabilities = [\"text\"]\n\n[storage_size]\nin_bytes = 39688355840\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Llama-3.1-Nemotron-70B-Instruct-HF-8bit.toml",
    "content": "model_id = \"mlx-community/Llama-3.1-Nemotron-70B-Instruct-HF-8bit\"\nn_layers = 80\nhidden_size = 8192\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"llama\"\nquantization = \"8bit\"\nbase_model = \"NVIDIA Llama-3.1-Nemotron-70B-Instruct\"\ncapabilities = [\"text\"]\n\n[storage_size]\nin_bytes = 74964549632\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Llama-3.1-Nemotron-70B-Instruct-HF-bf16.toml",
    "content": "model_id = \"mlx-community/Llama-3.1-Nemotron-70B-Instruct-HF-bf16\"\nn_layers = 80\nhidden_size = 8192\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"llama\"\nquantization = \"bf16\"\nbase_model = \"NVIDIA Llama-3.1-Nemotron-70B-Instruct\"\ncapabilities = [\"text\"]\n\n[storage_size]\nin_bytes = 141107412992\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Llama-3.1-Nemotron-Nano-4B-v1.1-4bit.toml",
    "content": "model_id = \"mlx-community/Llama-3.1-Nemotron-Nano-4B-v1.1-4bit\"\nn_layers = 32\nhidden_size = 3072\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"llama\"\nquantization = \"4bit\"\nbase_model = \"NVIDIA Llama-3.1-Nemotron-Nano-4B-v1.1\"\ncapabilities = [\"text\"]\n\n[storage_size]\nin_bytes = 2538706944\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Llama-3.1-Nemotron-Nano-4B-v1.1-8bit.toml",
    "content": "model_id = \"mlx-community/Llama-3.1-Nemotron-Nano-4B-v1.1-8bit\"\nn_layers = 32\nhidden_size = 3072\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"llama\"\nquantization = \"8bit\"\nbase_model = \"NVIDIA Llama-3.1-Nemotron-Nano-4B-v1.1\"\ncapabilities = [\"text\"]\n\n[storage_size]\nin_bytes = 4794980352\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Llama-3.1-Nemotron-Nano-4B-v1.1-bf16.toml",
    "content": "model_id = \"mlx-community/Llama-3.1-Nemotron-Nano-4B-v1.1-bf16\"\nn_layers = 32\nhidden_size = 3072\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"llama\"\nquantization = \"bf16\"\nbase_model = \"NVIDIA Llama-3.1-Nemotron-Nano-4B-v1.1\"\ncapabilities = [\"text\"]\n\n[storage_size]\nin_bytes = 9025492992\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Llama-3.2-1B-Instruct-4bit.toml",
    "content": "model_id = \"mlx-community/Llama-3.2-1B-Instruct-4bit\"\nn_layers = 16\nhidden_size = 2048\nnum_key_value_heads = 8\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"llama\"\nquantization = \"4bit\"\nbase_model = \"Llama 3.2 1B\"\ncapabilities = [\"text\"]\n\n[storage_size]\nin_bytes = 729808896\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Llama-3.2-3B-Instruct-4bit.toml",
    "content": "model_id = \"mlx-community/Llama-3.2-3B-Instruct-4bit\"\nn_layers = 28\nhidden_size = 3072\nnum_key_value_heads = 8\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"llama\"\nquantization = \"4bit\"\nbase_model = \"Llama 3.2 3B\"\ncapabilities = [\"text\"]\n\n[storage_size]\nin_bytes = 1863319552\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Llama-3.2-3B-Instruct-8bit.toml",
    "content": "model_id = \"mlx-community/Llama-3.2-3B-Instruct-8bit\"\nn_layers = 28\nhidden_size = 3072\nnum_key_value_heads = 8\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"llama\"\nquantization = \"8bit\"\nbase_model = \"Llama 3.2 3B\"\ncapabilities = [\"text\"]\n\n[storage_size]\nin_bytes = 3501195264\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Llama-3.3-70B-Instruct-4bit.toml",
    "content": "model_id = \"mlx-community/Llama-3.3-70B-Instruct-4bit\"\nn_layers = 80\nhidden_size = 8192\nnum_key_value_heads = 8\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"llama\"\nquantization = \"4bit\"\nbase_model = \"Llama 3.3 70B\"\ncapabilities = [\"text\"]\n\n[storage_size]\nin_bytes = 40652242944\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Llama-3.3-70B-Instruct-8bit.toml",
    "content": "model_id = \"mlx-community/Llama-3.3-70B-Instruct-8bit\"\nn_layers = 80\nhidden_size = 8192\nnum_key_value_heads = 8\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"llama\"\nquantization = \"8bit\"\nbase_model = \"Llama 3.3 70B\"\ncapabilities = [\"text\"]\n\n[storage_size]\nin_bytes = 76799803392\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Meta-Llama-3.1-70B-Instruct-4bit.toml",
    "content": "model_id = \"mlx-community/Meta-Llama-3.1-70B-Instruct-4bit\"\nn_layers = 80\nhidden_size = 8192\nnum_key_value_heads = 8\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"llama\"\nquantization = \"4bit\"\nbase_model = \"Llama 3.1 70B\"\ncapabilities = [\"text\"]\n\n[storage_size]\nin_bytes = 40652242944\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Meta-Llama-3.1-8B-Instruct-4bit.toml",
    "content": "model_id = \"mlx-community/Meta-Llama-3.1-8B-Instruct-4bit\"\nn_layers = 32\nhidden_size = 4096\nnum_key_value_heads = 8\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"llama\"\nquantization = \"4bit\"\nbase_model = \"Llama 3.1 8B\"\ncapabilities = [\"text\"]\n\n[storage_size]\nin_bytes = 4637851648\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Meta-Llama-3.1-8B-Instruct-8bit.toml",
    "content": "model_id = \"mlx-community/Meta-Llama-3.1-8B-Instruct-8bit\"\nn_layers = 32\nhidden_size = 4096\nnum_key_value_heads = 8\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"llama\"\nquantization = \"8bit\"\nbase_model = \"Llama 3.1 8B\"\ncapabilities = [\"text\"]\n\n[storage_size]\nin_bytes = 8954839040\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Meta-Llama-3.1-8B-Instruct-bf16.toml",
    "content": "model_id = \"mlx-community/Meta-Llama-3.1-8B-Instruct-bf16\"\nn_layers = 32\nhidden_size = 4096\nnum_key_value_heads = 8\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"llama\"\nquantization = \"bf16\"\nbase_model = \"Llama 3.1 8B\"\ncapabilities = [\"text\"]\n\n[storage_size]\nin_bytes = 16882073600\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--MiniMax-M2.1-3bit.toml",
    "content": "model_id = \"mlx-community/MiniMax-M2.1-3bit\"\nn_layers = 61\nhidden_size = 3072\nnum_key_value_heads = 8\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"minimax\"\nquantization = \"3bit\"\nbase_model = \"MiniMax M2.1\"\ncapabilities = [\"text\", \"thinking\", \"thinking_toggle\"]\n\n[storage_size]\nin_bytes = 100086644736\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--MiniMax-M2.1-8bit.toml",
    "content": "model_id = \"mlx-community/MiniMax-M2.1-8bit\"\nn_layers = 61\nhidden_size = 3072\nnum_key_value_heads = 8\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"minimax\"\nquantization = \"8bit\"\nbase_model = \"MiniMax M2.1\"\ncapabilities = [\"text\", \"thinking\", \"thinking_toggle\"]\n\n[storage_size]\nin_bytes = 242986745856\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--MiniMax-M2.5-4bit.toml",
    "content": "model_id = \"mlx-community/MiniMax-M2.5-4bit\"\nn_layers = 62\nhidden_size = 3072\nnum_key_value_heads = 8\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"minimax\"\nquantization = \"4bit\"\nbase_model = \"MiniMax M2.5\"\ncapabilities = [\"text\", \"thinking\"]\n\n[storage_size]\nin_bytes = 128666664960\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--MiniMax-M2.5-6bit.toml",
    "content": "model_id = \"mlx-community/MiniMax-M2.5-6bit\"\nn_layers = 62\nhidden_size = 3072\nnum_key_value_heads = 8\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"minimax\"\nquantization = \"6bit\"\nbase_model = \"MiniMax M2.5\"\ncapabilities = [\"text\", \"thinking\"]\n\n[storage_size]\nin_bytes = 185826705408\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--MiniMax-M2.5-8bit.toml",
    "content": "model_id = \"mlx-community/MiniMax-M2.5-8bit\"\nn_layers = 62\nhidden_size = 3072\nnum_key_value_heads = 8\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"minimax\"\nquantization = \"8bit\"\nbase_model = \"MiniMax M2.5\"\ncapabilities = [\"text\", \"thinking\"]\n\n[storage_size]\nin_bytes = 242986745856\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--NVIDIA-Nemotron-3-Nano-30B-A3B-MLX-4Bit.toml",
    "content": "model_id = \"mlx-community/NVIDIA-Nemotron-3-Nano-30B-A3B-MLX-4Bit\"\nn_layers = 52\nhidden_size = 2688\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"nemotron\"\nquantization = \"4bit\"\nbase_model = \"NVIDIA Nemotron-3-Nano-30B-A3B\"\ncapabilities = [\"text\"]\n\n[storage_size]\nin_bytes = 17775342336\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--NVIDIA-Nemotron-3-Nano-30B-A3B-MLX-5Bit.toml",
    "content": "model_id = \"mlx-community/NVIDIA-Nemotron-3-Nano-30B-A3B-MLX-5Bit\"\nn_layers = 52\nhidden_size = 2688\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"nemotron\"\nquantization = \"5bit\"\nbase_model = \"NVIDIA Nemotron-3-Nano-30B-A3B\"\ncapabilities = [\"text\"]\n\n[storage_size]\nin_bytes = 21721476864\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--NVIDIA-Nemotron-3-Nano-30B-A3B-MLX-6Bit.toml",
    "content": "model_id = \"mlx-community/NVIDIA-Nemotron-3-Nano-30B-A3B-MLX-6Bit\"\nn_layers = 52\nhidden_size = 2688\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"nemotron\"\nquantization = \"6bit\"\nbase_model = \"NVIDIA Nemotron-3-Nano-30B-A3B\"\ncapabilities = [\"text\"]\n\n[storage_size]\nin_bytes = 25667611392\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--NVIDIA-Nemotron-3-Nano-30B-A3B-MLX-8Bit.toml",
    "content": "model_id = \"mlx-community/NVIDIA-Nemotron-3-Nano-30B-A3B-MLX-8Bit\"\nn_layers = 52\nhidden_size = 2688\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"nemotron\"\nquantization = \"8bit\"\nbase_model = \"NVIDIA Nemotron-3-Nano-30B-A3B\"\ncapabilities = [\"text\"]\n\n[storage_size]\nin_bytes = 33559880448\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--NVIDIA-Nemotron-3-Nano-30B-A3B-MLX-BF16.toml",
    "content": "model_id = \"mlx-community/NVIDIA-Nemotron-3-Nano-30B-A3B-MLX-BF16\"\nn_layers = 52\nhidden_size = 2688\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"nemotron\"\nquantization = \"bf16\"\nbase_model = \"NVIDIA Nemotron-3-Nano-30B-A3B\"\ncapabilities = [\"text\"]\n\n[storage_size]\nin_bytes = 63155889408\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--NVIDIA-Nemotron-3-Nano-30B-A3B-MLX-MXFP4.toml",
    "content": "model_id = \"mlx-community/NVIDIA-Nemotron-3-Nano-30B-A3B-MLX-MXFP4\"\nn_layers = 52\nhidden_size = 2688\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"nemotron\"\nquantization = \"4bit\"\nbase_model = \"NVIDIA Nemotron-3-Nano-30B-A3B\"\ncapabilities = [\"text\"]\n\n[storage_size]\nin_bytes = 16788808704\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4.toml",
    "content": "model_id = \"mlx-community/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4\"\nn_layers = 52\nhidden_size = 2688\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"nemotron\"\nquantization = \"4bit\"\nbase_model = \"NVIDIA Nemotron-3-Nano-30B-A3B\"\ncapabilities = [\"text\"]\n\n[storage_size]\nin_bytes = 19323906944\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--NVIDIA-Nemotron-Nano-9B-v2-4bits.toml",
    "content": "model_id = \"mlx-community/NVIDIA-Nemotron-Nano-9B-v2-4bits\"\nn_layers = 56\nhidden_size = 4480\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"nemotron\"\nquantization = \"4bit\"\nbase_model = \"NVIDIA Nemotron-Nano-9B-v2\"\ncapabilities = [\"text\"]\n\n[storage_size]\nin_bytes = 5002791936\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--NVIDIA-Nemotron-Nano-9B-v2-6bit.toml",
    "content": "model_id = \"mlx-community/NVIDIA-Nemotron-Nano-9B-v2-6bit\"\nn_layers = 56\nhidden_size = 4480\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"nemotron\"\nquantization = \"6bit\"\nbase_model = \"NVIDIA Nemotron-Nano-9B-v2\"\ncapabilities = [\"text\"]\n\n[storage_size]\nin_bytes = 7224298496\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Qwen3-0.6B-4bit.toml",
    "content": "model_id = \"mlx-community/Qwen3-0.6B-4bit\"\nn_layers = 28\nhidden_size = 1024\nnum_key_value_heads = 8\nsupports_tensor = false\ntasks = [\"TextGeneration\"]\nfamily = \"qwen\"\nquantization = \"4bit\"\nbase_model = \"Qwen3 0.6B\"\ncapabilities = [\"text\", \"thinking\", \"thinking_toggle\"]\n\n[storage_size]\nin_bytes = 342884352\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Qwen3-0.6B-8bit.toml",
    "content": "model_id = \"mlx-community/Qwen3-0.6B-8bit\"\nn_layers = 28\nhidden_size = 1024\nnum_key_value_heads = 8\nsupports_tensor = false\ntasks = [\"TextGeneration\"]\nfamily = \"qwen\"\nquantization = \"8bit\"\nbase_model = \"Qwen3 0.6B\"\ncapabilities = [\"text\", \"thinking\", \"thinking_toggle\"]\n\n[storage_size]\nin_bytes = 698351616\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Qwen3-235B-A22B-Instruct-2507-4bit.toml",
    "content": "model_id = \"mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit\"\nn_layers = 94\nhidden_size = 4096\nnum_key_value_heads = 4\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"qwen\"\nquantization = \"4bit\"\nbase_model = \"Qwen3 235B\"\ncapabilities = [\"text\", \"thinking\", \"thinking_toggle\"]\n\n[storage_size]\nin_bytes = 141733920768\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Qwen3-235B-A22B-Instruct-2507-8bit.toml",
    "content": "model_id = \"mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit\"\nn_layers = 94\nhidden_size = 4096\nnum_key_value_heads = 4\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"qwen\"\nquantization = \"8bit\"\nbase_model = \"Qwen3 235B\"\ncapabilities = [\"text\", \"thinking\", \"thinking_toggle\"]\n\n[storage_size]\nin_bytes = 268435456000\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Qwen3-30B-A3B-4bit.toml",
    "content": "model_id = \"mlx-community/Qwen3-30B-A3B-4bit\"\nn_layers = 48\nhidden_size = 2048\nnum_key_value_heads = 4\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"qwen\"\nquantization = \"4bit\"\nbase_model = \"Qwen3 30B\"\ncapabilities = [\"text\", \"thinking\", \"thinking_toggle\"]\n\n[storage_size]\nin_bytes = 17612931072\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Qwen3-30B-A3B-8bit.toml",
    "content": "model_id = \"mlx-community/Qwen3-30B-A3B-8bit\"\nn_layers = 48\nhidden_size = 2048\nnum_key_value_heads = 4\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"qwen\"\nquantization = \"8bit\"\nbase_model = \"Qwen3 30B\"\ncapabilities = [\"text\", \"thinking\", \"thinking_toggle\"]\n\n[storage_size]\nin_bytes = 33279705088\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Qwen3-Coder-480B-A35B-Instruct-4bit.toml",
    "content": "model_id = \"mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit\"\nn_layers = 62\nhidden_size = 6144\nnum_key_value_heads = 8\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"qwen\"\nquantization = \"4bit\"\nbase_model = \"Qwen3 Coder 480B\"\ncapabilities = [\"text\", \"code\"]\n\n[storage_size]\nin_bytes = 289910292480\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Qwen3-Coder-480B-A35B-Instruct-8bit.toml",
    "content": "model_id = \"mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit\"\nn_layers = 62\nhidden_size = 6144\nnum_key_value_heads = 8\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"qwen\"\nquantization = \"8bit\"\nbase_model = \"Qwen3 Coder 480B\"\ncapabilities = [\"text\", \"code\"]\n\n[storage_size]\nin_bytes = 579820584960\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Qwen3-Coder-Next-4bit.toml",
    "content": "model_id = \"mlx-community/Qwen3-Coder-Next-4bit\"\nn_layers = 48\nhidden_size = 2048\nnum_key_value_heads = 2\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"qwen\"\nquantization = \"4bit\"\nbase_model = \"Qwen3 Coder Next\"\ncapabilities = [\"text\", \"code\"]\n\n[storage_size]\nin_bytes = 45644286500\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Qwen3-Coder-Next-5bit.toml",
    "content": "model_id = \"mlx-community/Qwen3-Coder-Next-5bit\"\nn_layers = 48\nhidden_size = 2048\nnum_key_value_heads = 2\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"qwen\"\nquantization = \"5bit\"\nbase_model = \"Qwen3 Coder Next\"\ncapabilities = [\"text\", \"code\"]\n\n[storage_size]\nin_bytes = 57657697020\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Qwen3-Coder-Next-6bit.toml",
    "content": "model_id = \"mlx-community/Qwen3-Coder-Next-6bit\"\nn_layers = 48\nhidden_size = 2048\nnum_key_value_heads = 2\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"qwen\"\nquantization = \"6bit\"\nbase_model = \"Qwen3 Coder Next\"\ncapabilities = [\"text\", \"code\"]\n\n[storage_size]\nin_bytes = 68899327465\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Qwen3-Coder-Next-8bit.toml",
    "content": "model_id = \"mlx-community/Qwen3-Coder-Next-8bit\"\nn_layers = 48\nhidden_size = 2048\nnum_key_value_heads = 2\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"qwen\"\nquantization = \"8bit\"\nbase_model = \"Qwen3 Coder Next\"\ncapabilities = [\"text\", \"code\"]\n\n[storage_size]\nin_bytes = 89357758772\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Qwen3-Coder-Next-bf16.toml",
    "content": "model_id = \"mlx-community/Qwen3-Coder-Next-bf16\"\nn_layers = 48\nhidden_size = 2048\nnum_key_value_heads = 2\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"qwen\"\nquantization = \"bf16\"\nbase_model = \"Qwen3 Coder Next\"\ncapabilities = [\"text\", \"code\"]\n\n[storage_size]\nin_bytes = 157548627945\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Qwen3-Next-80B-A3B-Instruct-4bit.toml",
    "content": "model_id = \"mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit\"\nn_layers = 48\nhidden_size = 2048\nnum_key_value_heads = 2\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"qwen\"\nquantization = \"4bit\"\nbase_model = \"Qwen3 Next 80B\"\ncapabilities = [\"text\"]\n\n[storage_size]\nin_bytes = 46976204800\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Qwen3-Next-80B-A3B-Instruct-8bit.toml",
    "content": "model_id = \"mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit\"\nn_layers = 48\nhidden_size = 2048\nnum_key_value_heads = 2\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"qwen\"\nquantization = \"8bit\"\nbase_model = \"Qwen3 Next 80B\"\ncapabilities = [\"text\"]\n\n[storage_size]\nin_bytes = 88814387200\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Qwen3-Next-80B-A3B-Thinking-4bit.toml",
    "content": "model_id = \"mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit\"\nn_layers = 48\nhidden_size = 2048\nnum_key_value_heads = 2\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"qwen\"\nquantization = \"4bit\"\nbase_model = \"Qwen3 Next 80B\"\ncapabilities = [\"text\", \"thinking\", \"thinking_toggle\"]\n\n[storage_size]\nin_bytes = 47080074240\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Qwen3-Next-80B-A3B-Thinking-8bit.toml",
    "content": "model_id = \"mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit\"\nn_layers = 48\nhidden_size = 2048\nnum_key_value_heads = 2\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"qwen\"\nquantization = \"8bit\"\nbase_model = \"Qwen3 Next 80B\"\ncapabilities = [\"text\", \"thinking\", \"thinking_toggle\"]\n\n[storage_size]\nin_bytes = 88814387200\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Qwen3.5-122B-A10B-4bit.toml",
    "content": "model_id = \"mlx-community/Qwen3.5-122B-A10B-4bit\"\nn_layers = 48\nhidden_size = 3072\nnum_key_value_heads = 2\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"qwen\"\nquantization = \"4bit\"\nbase_model = \"Qwen3.5 122B A10B\"\ncapabilities = [\"text\", \"thinking\", \"thinking_toggle\"]\n\n[storage_size]\nin_bytes = 69593314272\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Qwen3.5-122B-A10B-6bit.toml",
    "content": "model_id = \"mlx-community/Qwen3.5-122B-A10B-6bit\"\nn_layers = 48\nhidden_size = 3072\nnum_key_value_heads = 2\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"qwen\"\nquantization = \"6bit\"\nbase_model = \"Qwen3.5 122B A10B\"\ncapabilities = [\"text\", \"thinking\", \"thinking_toggle\"]\n\n[storage_size]\nin_bytes = 100120675296\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Qwen3.5-122B-A10B-8bit.toml",
    "content": "model_id = \"mlx-community/Qwen3.5-122B-A10B-8bit\"\nn_layers = 48\nhidden_size = 3072\nnum_key_value_heads = 2\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"qwen\"\nquantization = \"8bit\"\nbase_model = \"Qwen3.5 122B A10B\"\ncapabilities = [\"text\", \"thinking\", \"thinking_toggle\"]\n\n[storage_size]\nin_bytes = 130648036320\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Qwen3.5-122B-A10B-bf16.toml",
    "content": "model_id = \"mlx-community/Qwen3.5-122B-A10B-bf16\"\nn_layers = 48\nhidden_size = 3072\nnum_key_value_heads = 2\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"qwen\"\nquantization = \"bf16\"\nbase_model = \"Qwen3.5 122B A10B\"\ncapabilities = [\"text\", \"thinking\", \"thinking_toggle\"]\n\n[storage_size]\nin_bytes = 245125640160\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Qwen3.5-27B-4bit.toml",
    "content": "model_id = \"mlx-community/Qwen3.5-27B-4bit\"\nn_layers = 64\nhidden_size = 5120\nnum_key_value_heads = 4\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"qwen\"\nquantization = \"4bit\"\nbase_model = \"Qwen3.5 27B\"\ncapabilities = [\"text\", \"thinking\"]\n\n[storage_size]\nin_bytes = 16054266848\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Qwen3.5-27B-8bit.toml",
    "content": "model_id = \"mlx-community/Qwen3.5-27B-8bit\"\nn_layers = 64\nhidden_size = 5120\nnum_key_value_heads = 4\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"qwen\"\nquantization = \"8bit\"\nbase_model = \"Qwen3.5 27B\"\ncapabilities = [\"text\", \"thinking\"]\n\n[storage_size]\nin_bytes = 29500943328\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Qwen3.5-2B-MLX-8bit.toml",
    "content": "model_id = \"mlx-community/Qwen3.5-2B-MLX-8bit\"\nn_layers = 24\nhidden_size = 2048\nnum_key_value_heads = 2\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"qwen\"\nquantization = \"8bit\"\nbase_model = \"Qwen3.5 2B\"\ncapabilities = [\"text\", \"thinking\"]\n\n[storage_size]\nin_bytes = 2662787264\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Qwen3.5-35B-A3B-4bit.toml",
    "content": "model_id = \"mlx-community/Qwen3.5-35B-A3B-4bit\"\nn_layers = 40\nhidden_size = 2048\nnum_key_value_heads = 2\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"qwen\"\nquantization = \"4bit\"\nbase_model = \"Qwen3.5 35B A3B\"\ncapabilities = [\"text\", \"thinking\", \"thinking_toggle\"]\n\n[storage_size]\nin_bytes = 20391405152\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Qwen3.5-35B-A3B-8bit.toml",
    "content": "model_id = \"mlx-community/Qwen3.5-35B-A3B-8bit\"\nn_layers = 40\nhidden_size = 2048\nnum_key_value_heads = 2\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"qwen\"\nquantization = \"8bit\"\nbase_model = \"Qwen3.5 35B A3B\"\ncapabilities = [\"text\", \"thinking\", \"thinking_toggle\"]\n\n[storage_size]\nin_bytes = 37721130592\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Qwen3.5-397B-A17B-4bit.toml",
    "content": "model_id = \"mlx-community/Qwen3.5-397B-A17B-4bit\"\nn_layers = 60\nhidden_size = 4096\nnum_key_value_heads = 2\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"qwen\"\nquantization = \"4bit\"\nbase_model = \"Qwen3.5 397B A17B\"\ncapabilities = [\"text\", \"thinking\", \"thinking_toggle\"]\n\n[storage_size]\nin_bytes = 223860768352\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Qwen3.5-397B-A17B-6bit.toml",
    "content": "model_id = \"mlx-community/Qwen3.5-397B-A17B-6bit\"\nn_layers = 60\nhidden_size = 4096\nnum_key_value_heads = 2\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"qwen\"\nquantization = \"6bit\"\nbase_model = \"Qwen3.5 397B A17B\"\ncapabilities = [\"text\", \"thinking\", \"thinking_toggle\"]\n\n[storage_size]\nin_bytes = 322946674272\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Qwen3.5-397B-A17B-8bit.toml",
    "content": "model_id = \"mlx-community/Qwen3.5-397B-A17B-8bit\"\nn_layers = 60\nhidden_size = 4096\nnum_key_value_heads = 2\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"qwen\"\nquantization = \"8bit\"\nbase_model = \"Qwen3.5 397B A17B\"\ncapabilities = [\"text\", \"thinking\", \"thinking_toggle\"]\n\n[storage_size]\nin_bytes = 422032580192\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Qwen3.5-9B-4bit.toml",
    "content": "model_id = \"mlx-community/Qwen3.5-9B-4bit\"\nn_layers = 32\nhidden_size = 4096\nnum_key_value_heads = 4\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"qwen\"\nquantization = \"4bit\"\nbase_model = \"Qwen3.5 9B\"\ncapabilities = [\"text\", \"thinking\"]\n\n[storage_size]\nin_bytes = 5950062560\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Qwen3.5-9B-8bit.toml",
    "content": "model_id = \"mlx-community/Qwen3.5-9B-8bit\"\nn_layers = 32\nhidden_size = 4096\nnum_key_value_heads = 4\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"qwen\"\nquantization = \"8bit\"\nbase_model = \"Qwen3.5 9B\"\ncapabilities = [\"text\", \"thinking\"]\n\n[storage_size]\nin_bytes = 10426433504\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Step-3.5-Flash-4bit.toml",
    "content": "model_id = \"mlx-community/Step-3.5-Flash-4bit\"\nn_layers = 45\nhidden_size = 4096\nnum_key_value_heads = 8\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"step\"\nquantization = \"4bit\"\nbase_model = \"Step 3.5 Flash\"\ncapabilities = [\"text\", \"thinking\", \"thinking_toggle\"]\n\n[storage_size]\nin_bytes = 114572190076\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Step-3.5-Flash-6bit.toml",
    "content": "model_id = \"mlx-community/Step-3.5-Flash-6bit\"\nn_layers = 45\nhidden_size = 4096\nnum_key_value_heads = 8\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"step\"\nquantization = \"6bit\"\nbase_model = \"Step 3.5 Flash\"\ncapabilities = [\"text\", \"thinking\", \"thinking_toggle\"]\n\n[storage_size]\nin_bytes = 159039627774\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--Step-3.5-Flash-8Bit.toml",
    "content": "model_id = \"mlx-community/Step-3.5-Flash-8Bit\"\nn_layers = 45\nhidden_size = 4096\nnum_key_value_heads = 8\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"step\"\nquantization = \"8bit\"\nbase_model = \"Step 3.5 Flash\"\ncapabilities = [\"text\", \"thinking\", \"thinking_toggle\"]\n\n[storage_size]\nin_bytes = 209082699847\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--gpt-oss-120b-MXFP4-Q8.toml",
    "content": "model_id = \"mlx-community/gpt-oss-120b-MXFP4-Q8\"\nn_layers = 36\nhidden_size = 2880\nnum_key_value_heads = 8\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"gpt-oss\"\nquantization = \"MXFP4-Q8\"\nbase_model = \"GPT-OSS 120B\"\ncapabilities = [\"text\", \"thinking\"]\n\n[storage_size]\nin_bytes = 70652212224\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--gpt-oss-20b-MXFP4-Q8.toml",
    "content": "model_id = \"mlx-community/gpt-oss-20b-MXFP4-Q8\"\nn_layers = 24\nhidden_size = 2880\nnum_key_value_heads = 8\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"gpt-oss\"\nquantization = \"MXFP4-Q8\"\nbase_model = \"GPT-OSS 20B\"\ncapabilities = [\"text\", \"thinking\"]\n\n[storage_size]\nin_bytes = 12025908224\n"
  },
  {
    "path": "resources/inference_model_cards/mlx-community--llama-3.3-70b-instruct-fp16.toml",
    "content": "model_id = \"mlx-community/llama-3.3-70b-instruct-fp16\"\nn_layers = 80\nhidden_size = 8192\nnum_key_value_heads = 8\nsupports_tensor = true\ntasks = [\"TextGeneration\"]\nfamily = \"llama\"\nquantization = \"fp16\"\nbase_model = \"Llama 3.3 70B\"\ncapabilities = [\"text\"]\n\n[storage_size]\nin_bytes = 144383672320\n"
  },
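Each card above is one flat TOML document plus a single `[storage_size]` table, so downstream tooling can read it with the standard library alone. A minimal sketch, assuming Python 3.11+ for `tomllib`; `read_storage_gib` is a hypothetical helper, not code from this repo:

```python
# Hypothetical helper for reading the cards above; not part of the repository.
import tomllib
from pathlib import Path


def read_storage_gib(card_path: Path) -> float:
    """Return a model card's weight footprint in GiB."""
    card = tomllib.loads(card_path.read_text())
    return card["storage_size"]["in_bytes"] / 1024**3


# e.g. the fp16 Llama 3.3 70B card above: 144383672320 bytes is roughly 134.5 GiB.
```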
  {
    "path": "rust/exo_pyo3_bindings/Cargo.toml",
    "content": "[package]\nname = \"exo_pyo3_bindings\"\nversion = { workspace = true }\nedition = { workspace = true }\npublish = false\n\n[lib]\ndoctest = false\npath = \"src/lib.rs\"\nname = \"exo_pyo3_bindings\"\n\n# \"cdylib\" needed to produce shared library for Python to import\n# \"rlib\" needed for stub-gen to run\ncrate-type = [\"cdylib\", \"rlib\"]\n\n[[bin]]\npath = \"src/bin/stub_gen.rs\"\nname = \"stub_gen\"\ndoc = false\n\n[lints]\nworkspace = true\n\n[dependencies]\nnetworking = { workspace = true }\n\n# interop\npyo3 = { version = \"0.27.2\", features = [\n  # \"abi3-py313\", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.13\n  # \"nightly\", # enables better-supported GIL integration\n  \"experimental-async\", # async support in #[pyfunction] & #[pymethods]\n  #\"experimental-inspect\", # inspection of generated binary => easier to automate type-hint generation\n  #\"py-clone\", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)\n  # \"multiple-pymethods\", # allows multiple #[pymethods] sections per class\n\n  # integrations with other libraries\n  # \"arc_lock\", \"bigdecimal\", \"either\", \"hashbrown\", \"indexmap\", \"num-bigint\", \"num-complex\", \"num-rational\",\n  # \"ordered-float\", \"rust_decimal\", \"smallvec\",\n  # \"anyhow\", \"chrono\", \"chrono-local\", \"chrono-tz\", \"eyre\", \"jiff-02\", \"lock_api\", \"parking-lot\", \"time\",  \"serde\",\n] }\npyo3-stub-gen = { version = \"0.17.2\" }\npyo3-async-runtimes = { version = \"0.27.0\", features = [\n  \"attributes\",\n  \"tokio-runtime\",\n  \"testing\",\n] }\npyo3-log = \"0.13.2\"\n\n# macro dependencies\nextend = { workspace = true }\ndelegate = { workspace = true }\n\n# async runtime\ntokio = { workspace = true, features = [\"full\", \"tracing\"] }\nfutures-lite = { workspace = true }\n\n# utility dependencies\nutil = { workspace = true }\n\n# Tracing\nlog = { workspace = true }\nenv_logger = \"0.11\"\n\n# Networking\nlibp2p = { workspace = true, features = [\"full\"] }\npin-project = \"1.1.10\"\n"
  },
  {
    "path": "rust/exo_pyo3_bindings/README.md",
    "content": "TODO: do something here....\n"
  },
  {
    "path": "rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi",
    "content": "# This file is automatically generated by pyo3_stub_gen\n# ruff: noqa: E501, F401\n\nimport builtins\nimport typing\n\n@typing.final\nclass AllQueuesFullError(builtins.Exception):\n    def __new__(cls, *args: typing.Any) -> AllQueuesFullError: ...\n    def __repr__(self) -> builtins.str: ...\n    def __str__(self) -> builtins.str: ...\n\n@typing.final\nclass Keypair:\n    r\"\"\"\n    Identity keypair of a node.\n    \"\"\"\n    @staticmethod\n    def generate() -> Keypair:\n        r\"\"\"\n        Generate a new Ed25519 keypair.\n        \"\"\"\n    @staticmethod\n    def from_bytes(bytes: bytes) -> Keypair:\n        r\"\"\"\n        Construct an Ed25519 keypair from secret key bytes\n        \"\"\"\n    def to_bytes(self) -> bytes:\n        r\"\"\"\n        Get the secret key bytes underlying the keypair\n        \"\"\"\n    def to_node_id(self) -> builtins.str:\n        r\"\"\"\n        Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.\n        \"\"\"\n\n@typing.final\nclass MessageTooLargeError(builtins.Exception):\n    def __new__(cls, *args: typing.Any) -> MessageTooLargeError: ...\n    def __repr__(self) -> builtins.str: ...\n    def __str__(self) -> builtins.str: ...\n\n@typing.final\nclass NetworkingHandle:\n    def __new__(cls, identity: Keypair) -> NetworkingHandle: ...\n    async def gossipsub_subscribe(self, topic: builtins.str) -> builtins.bool:\n        r\"\"\"\n        Subscribe to a `GossipSub` topic.\n        \n        Returns `True` if the subscription worked. Returns `False` if we were already subscribed.\n        \"\"\"\n    async def gossipsub_unsubscribe(self, topic: builtins.str) -> builtins.bool:\n        r\"\"\"\n        Unsubscribes from a `GossipSub` topic.\n        \n        Returns `True` if we were subscribed to this topic. Returns `False` if we were not subscribed.\n        \"\"\"\n    async def gossipsub_publish(self, topic: builtins.str, data: bytes) -> None:\n        r\"\"\"\n        Publishes a message with multiple topics to the `GossipSub` network.\n        \n        If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception.\n        \"\"\"\n    async def recv(self) -> PyFromSwarm: ...\n\n@typing.final\nclass NoPeersSubscribedToTopicError(builtins.Exception):\n    def __new__(cls, *args: typing.Any) -> NoPeersSubscribedToTopicError: ...\n    def __repr__(self) -> builtins.str: ...\n    def __str__(self) -> builtins.str: ...\n\nclass PyFromSwarm:\n    @typing.final\n    class Connection(PyFromSwarm):\n        __match_args__ = (\"peer_id\", \"connected\",)\n        @property\n        def peer_id(self) -> builtins.str: ...\n        @property\n        def connected(self) -> builtins.bool: ...\n        def __new__(cls, peer_id: builtins.str, connected: builtins.bool) -> PyFromSwarm.Connection: ...\n    \n    @typing.final\n    class Message(PyFromSwarm):\n        __match_args__ = (\"origin\", \"topic\", \"data\",)\n        @property\n        def origin(self) -> builtins.str: ...\n        @property\n        def topic(self) -> builtins.str: ...\n        @property\n        def data(self) -> bytes: ...\n        def __new__(cls, origin: builtins.str, topic: builtins.str, data: bytes) -> PyFromSwarm.Message: ...\n    \n    ...\n\n"
  },
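Because the generated stubs give `PyFromSwarm.Connection` and `PyFromSwarm.Message` `__match_args__`, swarm events can be consumed with structural pattern matching. A minimal consumer sketch against this API; the topic name and prints are illustrative, not repo code:

```python
import asyncio

from exo_pyo3_bindings import Keypair, NetworkingHandle, PyFromSwarm


async def watch_swarm() -> None:
    handle = NetworkingHandle(Keypair.generate())
    await handle.gossipsub_subscribe("demo-topic")  # illustrative topic name
    while True:
        match await handle.recv():
            case PyFromSwarm.Connection(peer_id=peer, connected=up):
                print(f"peer {peer} {'connected' if up else 'disconnected'}")
            case PyFromSwarm.Message(origin=origin, topic=topic, data=data):
                print(f"{origin} @ {topic}: {data!r}")


asyncio.run(watch_swarm())
```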
  {
    "path": "rust/exo_pyo3_bindings/pyproject.toml",
    "content": "[build-system]\nrequires = [\"maturin>=1.0,<2.0\"]\nbuild-backend = \"maturin\"\n\n[project]\nname = \"exo_pyo3_bindings\"\nversion = \"0.2.1\"\ndescription = \"Add your description here\"\nreadme = \"README.md\"\nauthors = [\n  { name = \"Andrei Cravtov\", email = \"the.andrei.cravtov@gmail.com\" },\n  { name = \"Evan Quiney\", email = \"evanev7@gmail.com\" },\n]\nrequires-python = \">=3.13\"\ndependencies = []\n\n[dependency-groups]\ndev = [\"exo_pyo3_bindings\", \"pytest>=8.4.0\", \"pytest-asyncio>=1.0.0\"]\n\n[tool.maturin]\n#purelib = true\n#python-source = \"python\"\nmodule-name = \"exo_pyo3_bindings\"\nfeatures = [\"pyo3/extension-module\", \"pyo3/experimental-async\"]\n\n[tool.pytest.ini_options]\nlog_cli = true\nlog_cli_level = \"INFO\"\nasyncio_mode = \"auto\"\n"
  },
  {
    "path": "rust/exo_pyo3_bindings/src/allow_threading.rs",
    "content": "//! SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await\n//!\n\nuse pin_project::pin_project;\nuse pyo3::prelude::*;\nuse std::{\n    future::Future,\n    pin::Pin,\n    task::{Context, Poll},\n};\n\n/// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await\n#[pin_project]\n#[repr(transparent)]\npub(crate) struct AllowThreads<F>(#[pin] F);\n\nimpl<F> AllowThreads<F>\nwhere\n    Self: Future,\n{\n    pub fn new(f: F) -> Self {\n        Self(f)\n    }\n}\n\nimpl<F> Future for AllowThreads<F>\nwhere\n    F: Future + Send,\n    F::Output: Send,\n{\n    type Output = F::Output;\n\n    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {\n        let waker = cx.waker();\n        Python::attach(|py| py.detach(|| self.project().0.poll(&mut Context::from_waker(waker))))\n    }\n}\n"
  },
  {
    "path": "rust/exo_pyo3_bindings/src/bin/stub_gen.rs",
    "content": "use pyo3_stub_gen::Result;\n\nfn main() -> Result<()> {\n    env_logger::Builder::from_env(env_logger::Env::default().filter_or(\"RUST_LOG\", \"info\")).init();\n    let stub = exo_pyo3_bindings::stub_info()?;\n    stub.generate()?;\n    Ok(())\n}\n"
  },
  {
    "path": "rust/exo_pyo3_bindings/src/ident.rs",
    "content": "use crate::ext::ResultExt as _;\nuse libp2p::identity::Keypair;\nuse pyo3::types::{PyBytes, PyBytesMethods as _};\nuse pyo3::{Bound, PyResult, Python, pyclass, pymethods};\nuse pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};\n\n/// Identity keypair of a node.\n#[gen_stub_pyclass]\n#[pyclass(name = \"Keypair\", frozen)]\n#[repr(transparent)]\npub struct PyKeypair(pub Keypair);\n\n#[gen_stub_pymethods]\n#[pymethods]\n#[allow(clippy::needless_pass_by_value)]\nimpl PyKeypair {\n    /// Generate a new Ed25519 keypair.\n    #[staticmethod]\n    fn generate() -> Self {\n        Self(Keypair::generate_ed25519())\n    }\n\n    /// Construct an Ed25519 keypair from secret key bytes\n    #[staticmethod]\n    fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {\n        let mut bytes = Vec::from(bytes.as_bytes());\n        Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))\n    }\n\n    /// Get the secret key bytes underlying the keypair\n    fn to_bytes<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {\n        let bytes = self\n            .0\n            .clone()\n            .try_into_ed25519()\n            .pyerr()?\n            .secret()\n            .as_ref()\n            .to_vec();\n        Ok(PyBytes::new(py, &bytes))\n    }\n\n    /// Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.\n    fn to_node_id(&self) -> String {\n        self.0.public().to_peer_id().to_base58()\n    }\n}\n"
  },
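Since `from_bytes` and `to_bytes` round-trip the Ed25519 secret key, a node can persist its identity across restarts. A minimal sketch against the bindings above; the assertion is illustrative:

```python
from exo_pyo3_bindings import Keypair

# Generate once, persist the secret bytes, and restore the same node identity later.
keypair = Keypair.generate()
secret = keypair.to_bytes()  # Ed25519 secret key bytes

restored = Keypair.from_bytes(secret)
assert restored.to_node_id() == keypair.to_node_id()  # same PeerId-derived NodeId
```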
  {
    "path": "rust/exo_pyo3_bindings/src/lib.rs",
    "content": "//! TODO: crate documentation\n//!\n//! this is here as a placeholder documentation\n//!\n//!\n\nmod allow_threading;\nmod ident;\nmod networking;\n\nuse crate::ident::PyKeypair;\nuse crate::networking::networking_submodule;\nuse pyo3::prelude::PyModule;\nuse pyo3::types::PyModuleMethods;\nuse pyo3::{Bound, PyResult, pyclass, pymodule};\nuse pyo3_stub_gen::define_stub_info_gatherer;\n\n/// Namespace for all the constants used by this crate.\npub(crate) mod r#const {\n    pub const MPSC_CHANNEL_SIZE: usize = 1024;\n}\n\n/// Namespace for crate-wide extension traits/methods\npub(crate) mod ext {\n    use crate::allow_threading::AllowThreads;\n    use extend::ext;\n    use pyo3::exceptions::{PyConnectionError, PyRuntimeError};\n    use pyo3::types::PyBytes;\n    use pyo3::{Py, PyErr, PyResult, Python};\n    use tokio::runtime::Runtime;\n    use tokio::sync::mpsc;\n    use tokio::sync::mpsc::error::TryRecvError;\n    use tokio::task::JoinHandle;\n\n    #[ext(pub, name = ByteArrayExt)]\n    impl [u8] {\n        fn pybytes(&self) -> Py<PyBytes> {\n            Python::attach(|py| PyBytes::new(py, self).unbind())\n        }\n    }\n\n    #[ext(pub, name = ResultExt)]\n    impl<T, E> Result<T, E>\n    where\n        E: ToString,\n    {\n        fn pyerr(self) -> PyResult<T> {\n            self.map_err(|e| PyRuntimeError::new_err(e.to_string()))\n        }\n    }\n\n    pub trait FutureExt: Future + Sized {\n        /// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await\n        fn allow_threads_py(self) -> AllowThreads<Self>\n        where\n            AllowThreads<Self>: Future,\n        {\n            AllowThreads::new(self)\n        }\n    }\n\n    impl<T: Future> FutureExt for T {}\n\n    #[ext(pub, name = PyErrExt)]\n    impl PyErr {\n        fn receiver_channel_closed() -> Self {\n            PyConnectionError::new_err(\"Receiver channel closed unexpectedly\")\n        }\n    }\n\n    #[ext(pub, name = PyResultExt)]\n    impl<T> PyResult<T> {\n        fn write_unraisable(self) -> Option<T> {\n            Python::attach(|py| self.write_unraisable_with(py))\n        }\n\n        fn write_unraisable_with(self, py: Python<'_>) -> Option<T> {\n            match self {\n                Ok(v) => Some(v),\n                Err(e) => {\n                    // write error back to python\n                    e.write_unraisable(py, None);\n                    None\n                }\n            }\n        }\n    }\n\n    #[ext(pub, name = TokioRuntimeExt)]\n    impl Runtime {\n        fn spawn_with_scope<F>(&self, py: Python<'_>, future: F) -> PyResult<JoinHandle<F::Output>>\n        where\n            F: Future + Send + 'static,\n            F::Output: Send + 'static,\n        {\n            let locals = pyo3_async_runtimes::tokio::get_current_locals(py)?;\n            Ok(self.spawn(pyo3_async_runtimes::tokio::scope(locals, future)))\n        }\n    }\n\n    #[ext(pub, name = TokioMpscSenderExt)]\n    impl<T> mpsc::Sender<T> {\n        /// Sends a value, waiting until there is capacity.\n        ///\n        /// A successful send occurs when it is determined that the other end of the\n        /// channel has not hung up already. 
An unsuccessful send would be one where\n        /// the corresponding receiver has already been closed.\n        async fn send_py(&self, value: T) -> PyResult<()> {\n            self.send(value)\n                .await\n                .map_err(|_| PyErr::receiver_channel_closed())\n        }\n    }\n\n    #[ext(pub, name = TokioMpscReceiverExt)]\n    impl<T> mpsc::Receiver<T> {\n        /// Receives the next value for this receiver.\n        async fn recv_py(&mut self) -> PyResult<T> {\n            self.recv().await.ok_or_else(PyErr::receiver_channel_closed)\n        }\n\n        /// Receives at most `limit` values for this receiver and returns them.\n        ///\n        /// For `limit = 0`, an empty collection of messages will be returned immediately.\n        /// For `limit > 0`, if there are no messages in the channel's queue this method\n        /// will sleep until a message is sent.\n        async fn recv_many_py(&mut self, limit: usize) -> PyResult<Vec<T>> {\n            // get updates from receiver channel\n            let mut updates = Vec::with_capacity(limit);\n            let received = self.recv_many(&mut updates, limit).await;\n\n            // if we received zero items, then the channel was unexpectedly closed\n            if limit != 0 && received == 0 {\n                return Err(PyErr::receiver_channel_closed());\n            }\n\n            Ok(updates)\n        }\n\n        /// Tries to receive the next value for this receiver.\n        fn try_recv_py(&mut self) -> PyResult<Option<T>> {\n            match self.try_recv() {\n                Ok(v) => Ok(Some(v)),\n                Err(TryRecvError::Empty) => Ok(None),\n                Err(TryRecvError::Disconnected) => Err(PyErr::receiver_channel_closed()),\n            }\n        }\n    }\n}\n\n/// A Python module implemented in Rust. The name of this function must match\n/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to\n/// import the module.\n#[pymodule(name = \"exo_pyo3_bindings\")]\nfn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {\n    // install logger\n    pyo3_log::init();\n    let mut builder = tokio::runtime::Builder::new_multi_thread();\n    builder.enable_all();\n    pyo3_async_runtimes::tokio::init(builder);\n\n    // TODO: for now this is all NOT a submodule, but figure out how to make the submodule system\n    //       work with maturin, where the types generate correctly, in the right folder, without\n    //       too many importing issues...\n    m.add_class::<PyKeypair>()?;\n    networking_submodule(m)?;\n\n    // top-level constructs\n    // TODO: ...\n\n    Ok(())\n}\n\ndefine_stub_info_gatherer!(stub_info);\n"
  },
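`recv_many_py` above encodes a small contract: `limit = 0` returns an empty list immediately, otherwise the call sleeps until at least one value arrives, and a zero-item receive means the channel closed. A rough asyncio mirror of that contract, purely illustrative (the `CLOSED` sentinel and helper are not repo code):

```python
import asyncio

CLOSED = object()  # sentinel standing in for a closed channel


async def recv_many(queue: asyncio.Queue, limit: int) -> list:
    """Mirror the recv_many_py contract from the Rust extension trait."""
    if limit == 0:
        return []  # limit = 0 -> return immediately
    first = await queue.get()  # sleep until at least one item arrives
    if first is CLOSED:
        raise ConnectionError("receiver channel closed unexpectedly")
    items = [first]
    while len(items) < limit:
        try:
            item = queue.get_nowait()  # drain whatever is already buffered
        except asyncio.QueueEmpty:
            break
        if item is CLOSED:
            break
        items.append(item)
    return items
```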
  {
    "path": "rust/exo_pyo3_bindings/src/networking.rs",
    "content": "use std::pin::Pin;\nuse std::sync::Arc;\n\nuse crate::r#const::MPSC_CHANNEL_SIZE;\nuse crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};\nuse crate::ext::{ResultExt as _, TokioMpscSenderExt as _};\nuse crate::ident::PyKeypair;\nuse crate::networking::exception::{\n    PyAllQueuesFullError, PyMessageTooLargeError, PyNoPeersSubscribedToTopicError,\n};\nuse crate::pyclass;\nuse futures_lite::{Stream, StreamExt as _};\nuse libp2p::gossipsub::PublishError;\nuse networking::swarm::{FromSwarm, ToSwarm, create_swarm};\nuse pyo3::exceptions::PyRuntimeError;\nuse pyo3::prelude::{PyModule, PyModuleMethods as _};\nuse pyo3::types::PyBytes;\nuse pyo3::{Bound, Py, PyAny, PyErr, PyResult, Python, pymethods};\nuse pyo3_stub_gen::derive::{\n    gen_methods_from_python, gen_stub_pyclass, gen_stub_pyclass_complex_enum, gen_stub_pymethods,\n};\nuse tokio::sync::{Mutex, mpsc, oneshot};\n\nmod exception {\n    use pyo3::types::PyTuple;\n    use pyo3::{exceptions::PyException, prelude::*};\n    use pyo3_stub_gen::derive::*;\n\n    #[gen_stub_pyclass]\n    #[pyclass(frozen, extends=PyException, name=\"NoPeersSubscribedToTopicError\")]\n    pub struct PyNoPeersSubscribedToTopicError {}\n\n    impl PyNoPeersSubscribedToTopicError {\n        const MSG: &'static str = \"\\\n        No peers are currently subscribed to receive messages on this topic. \\\n        Wait for peers to subscribe or check your network connectivity.\";\n\n        ///   Creates a new  [ `PyErr` ]  of this type.\n        ///\n        ///   [`PyErr`] :  https://docs.rs/pyo3/latest/pyo3/struct.PyErr.html   \"PyErr in pyo3\"\n        pub(crate) fn new_err() -> PyErr {\n            PyErr::new::<Self, _>(()) // TODO: check if this needs to be replaced???\n        }\n    }\n\n    #[gen_stub_pymethods]\n    #[pymethods]\n    impl PyNoPeersSubscribedToTopicError {\n        #[new]\n        #[pyo3(signature = (*args))]\n        #[allow(unused_variables)]\n        pub(crate) fn new(args: &Bound<'_, PyTuple>) -> Self {\n            Self {}\n        }\n\n        fn __repr__(&self) -> String {\n            format!(\"PeerId(\\\"{}\\\")\", Self::MSG)\n        }\n\n        fn __str__(&self) -> String {\n            Self::MSG.to_string()\n        }\n    }\n\n    #[gen_stub_pyclass]\n    #[pyclass(frozen, extends=PyException, name=\"AllQueuesFullError\")]\n    pub struct PyAllQueuesFullError {}\n\n    impl PyAllQueuesFullError {\n        const MSG: &'static str =\n            \"All libp2p peers are unresponsive, resend the message or reconnect.\";\n\n        ///   Creates a new  [ `PyErr` ]  of this type.\n        ///\n        ///   [`PyErr`] :  https://docs.rs/pyo3/latest/pyo3/struct.PyErr.html   \"PyErr in pyo3\"\n        pub(crate) fn new_err() -> PyErr {\n            PyErr::new::<Self, _>(()) // TODO: check if this needs to be replaced???\n        }\n    }\n\n    #[gen_stub_pymethods]\n    #[pymethods]\n    impl PyAllQueuesFullError {\n        #[new]\n        #[pyo3(signature = (*args))]\n        #[allow(unused_variables)]\n        pub(crate) fn new(args: &Bound<'_, PyTuple>) -> Self {\n            Self {}\n        }\n\n        fn __repr__(&self) -> String {\n            format!(\"PeerId(\\\"{}\\\")\", Self::MSG)\n        }\n\n        fn __str__(&self) -> String {\n            Self::MSG.to_string()\n        }\n    }\n\n    #[gen_stub_pyclass]\n    #[pyclass(frozen, extends=PyException, name=\"MessageTooLargeError\")]\n    pub struct PyMessageTooLargeError {}\n\n    impl PyMessageTooLargeError {\n        const MSG: &'static str = 
\"Gossipsub message exceeds max_transmit_size. Reduce prompt length or increase the limit.\";\n\n        pub(crate) fn new_err() -> PyErr {\n            PyErr::new::<Self, _>(())\n        }\n    }\n\n    #[gen_stub_pymethods]\n    #[pymethods]\n    impl PyMessageTooLargeError {\n        #[new]\n        #[pyo3(signature = (*args))]\n        #[allow(unused_variables)]\n        pub(crate) fn new(args: &Bound<'_, PyTuple>) -> Self {\n            Self {}\n        }\n\n        fn __repr__(&self) -> String {\n            format!(\"MessageTooLargeError(\\\"{}\\\")\", Self::MSG)\n        }\n\n        fn __str__(&self) -> String {\n            Self::MSG.to_string()\n        }\n    }\n}\n\n#[gen_stub_pyclass]\n#[pyclass(name = \"NetworkingHandle\")]\nstruct PyNetworkingHandle {\n    // channels\n    pub to_swarm: mpsc::Sender<ToSwarm>,\n    pub swarm: Arc<Mutex<Pin<Box<dyn Stream<Item = FromSwarm> + Send>>>>,\n}\n\n#[gen_stub_pyclass_complex_enum]\n#[pyclass]\nenum PyFromSwarm {\n    Connection {\n        peer_id: String,\n        connected: bool,\n    },\n    Message {\n        origin: String,\n        topic: String,\n        data: Py<PyBytes>,\n    },\n}\nimpl From<FromSwarm> for PyFromSwarm {\n    fn from(value: FromSwarm) -> Self {\n        match value {\n            FromSwarm::Discovered { peer_id } => Self::Connection {\n                peer_id: peer_id.to_base58(),\n                connected: true,\n            },\n            FromSwarm::Expired { peer_id } => Self::Connection {\n                peer_id: peer_id.to_base58(),\n                connected: false,\n            },\n            FromSwarm::Message { from, topic, data } => Self::Message {\n                origin: from.to_base58(),\n                topic: topic,\n                data: data.pybytes(),\n            },\n        }\n    }\n}\n\n#[gen_stub_pymethods]\n#[pymethods]\nimpl PyNetworkingHandle {\n    // NOTE: `async fn`s here that use `.await` will wrap the future in `.allow_threads_py()`\n    //       immediately beforehand to release the interpreter.\n    //       SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await\n\n    // ---- Lifecycle management methods ----\n\n    #[new]\n    fn py_new(identity: Bound<'_, PyKeypair>) -> PyResult<Self> {\n        // create communication channels\n        let (to_swarm, from_client) = mpsc::channel(MPSC_CHANNEL_SIZE);\n\n        // get identity\n        let identity = identity.borrow().0.clone();\n\n        // create networking swarm (within tokio context!! or it crashes)\n        let _guard = pyo3_async_runtimes::tokio::get_runtime().enter();\n        let swarm = create_swarm(identity, from_client).pyerr()?.into_stream();\n\n        Ok(Self {\n            swarm: Arc::new(Mutex::new(swarm)),\n            to_swarm,\n        })\n    }\n\n    #[gen_stub(skip)]\n    fn recv<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {\n        let swarm = Arc::clone(&self.swarm);\n        pyo3_async_runtimes::tokio::future_into_py(py, async move {\n            swarm\n                .try_lock()\n                .map_err(|_| PyRuntimeError::new_err(\"called recv twice concurrently\"))?\n                .next()\n                .await\n                .ok_or(PyErr::receiver_channel_closed())\n                .map(PyFromSwarm::from)\n        })\n    }\n\n    // ---- Gossipsub management methods ----\n\n    /// Subscribe to a `GossipSub` topic.\n    ///\n    /// Returns `True` if the subscription worked. 
Returns `False` if we were already subscribed.\n    async fn gossipsub_subscribe(&self, topic: String) -> PyResult<bool> {\n        let (tx, rx) = oneshot::channel();\n\n        // send off request to subscribe\n        self.to_swarm\n            .send_py(ToSwarm::Subscribe {\n                topic,\n                result_sender: tx,\n            })\n            .allow_threads_py() // allow-threads-aware async call\n            .await?;\n\n        // wait for response & return any errors\n        rx.allow_threads_py() // allow-threads-aware async call\n            .await\n            .map_err(|_| PyErr::receiver_channel_closed())?\n            .pyerr()\n    }\n\n    /// Unsubscribes from a `GossipSub` topic.\n    ///\n    /// Returns `True` if we were subscribed to this topic. Returns `False` if we were not subscribed.\n    async fn gossipsub_unsubscribe(&self, topic: String) -> PyResult<bool> {\n        let (tx, rx) = oneshot::channel();\n\n        // send off request to unsubscribe\n        self.to_swarm\n            .send_py(ToSwarm::Unsubscribe {\n                topic,\n                result_sender: tx,\n            })\n            .allow_threads_py() // allow-threads-aware async call\n            .await?;\n\n        // wait for response & convert any errors\n        rx.allow_threads_py() // allow-threads-aware async call\n            .await\n            .map_err(|_| PyErr::receiver_channel_closed())\n    }\n\n    /// Publishes a message with multiple topics to the `GossipSub` network.\n    ///\n    /// If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception.\n    async fn gossipsub_publish(&self, topic: String, data: Py<PyBytes>) -> PyResult<()> {\n        let (tx, rx) = oneshot::channel();\n\n        // send off request to subscribe\n        let data = Python::attach(|py| Vec::from(data.as_bytes(py)));\n        self.to_swarm\n            .send_py(ToSwarm::Publish {\n                topic,\n                data,\n                result_sender: tx,\n            })\n            .allow_threads_py() // allow-threads-aware async call\n            .await?;\n\n        // wait for response & return any errors => ignore messageID for now!!!\n        let _ = rx\n            .allow_threads_py() // allow-threads-aware async call\n            .await\n            .map_err(|_| PyErr::receiver_channel_closed())?\n            .map_err(|e| match e {\n                PublishError::AllQueuesFull(_) => PyAllQueuesFullError::new_err(),\n                PublishError::MessageTooLarge => PyMessageTooLargeError::new_err(),\n                PublishError::NoPeersSubscribedToTopic => {\n                    PyNoPeersSubscribedToTopicError::new_err()\n                }\n                e => PyRuntimeError::new_err(e.to_string()),\n            })?;\n        Ok(())\n    }\n}\n\npyo3_stub_gen::inventory::submit! {\n    gen_methods_from_python! {\n        r#\"\n            class PyNetworkingHandle:\n                async def recv() -> PyFromSwarm: ...\n        \"#\n    }\n}\n\npub fn networking_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {\n    m.add_class::<exception::PyNoPeersSubscribedToTopicError>()?;\n    m.add_class::<exception::PyAllQueuesFullError>()?;\n    m.add_class::<exception::PyMessageTooLargeError>()?;\n\n    m.add_class::<PyNetworkingHandle>()?;\n    m.add_class::<PyFromSwarm>()?;\n\n    Ok(())\n}\n"
  },
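On the Python side, the three exception classes let callers distinguish "no subscribers yet" from "peers overloaded" from "payload too big". A minimal sketch of that mapping; `publish_with_report` is a hypothetical helper, and the 8 MiB figure comes from the gossipsub `max_transmit_size` configured in `swarm.rs`:

```python
from exo_pyo3_bindings import (
    AllQueuesFullError,
    MessageTooLargeError,
    NetworkingHandle,
    NoPeersSubscribedToTopicError,
)


async def publish_with_report(handle: NetworkingHandle, topic: str, data: bytes) -> bool:
    """Try one publish and map each binding-level exception to a log line."""
    try:
        await handle.gossipsub_publish(topic, data)
    except NoPeersSubscribedToTopicError:
        print("no peers on this topic yet; retry after discovery")
    except AllQueuesFullError:
        print("peers unresponsive; resend or reconnect")
    except MessageTooLargeError:
        print("payload exceeds max_transmit_size (8 MiB in this repo's config)")
    else:
        return True
    return False
```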
  {
    "path": "rust/exo_pyo3_bindings/tests/dummy.rs",
    "content": "#[cfg(test)]\nmod tests {\n    use core::mem::drop;\n    use core::option::Option::Some;\n    use core::time::Duration;\n    use tokio;\n    use tokio::sync::mpsc;\n\n    #[tokio::test]\n    async fn test_drop_channel() {\n        struct Ping;\n\n        let (tx, mut rx) = mpsc::channel::<Ping>(10);\n\n        let _ = tokio::spawn(async move {\n            println!(\"TASK: entered\");\n\n            loop {\n                tokio::select! {\n                    result = rx.recv() => {\n                        match result {\n                            Some(_) => {\n                                println!(\"TASK: pinged\");\n                            }\n                            None => {\n                                println!(\"TASK: closing channel\");\n                                break;\n                            }\n                        }\n                    }\n                    _ = tokio::time::sleep(Duration::from_secs_f32(0.1)) => {\n                        println!(\"TASK: heartbeat\");\n                    }\n                }\n            }\n\n            println!(\"TASK: exited\");\n        });\n\n        let tx2 = tx.clone();\n\n        tokio::time::sleep(Duration::from_secs_f32(0.11)).await;\n\n        tx.send(Ping).await.expect(\"Should not fail\");\n        drop(tx);\n\n        tokio::time::sleep(Duration::from_secs_f32(0.11)).await;\n\n        tx2.send(Ping).await.expect(\"Should not fail\");\n        drop(tx2);\n\n        tokio::time::sleep(Duration::from_secs_f32(0.11)).await;\n    }\n}\n"
  },
  {
    "path": "rust/exo_pyo3_bindings/tests/test_python.py",
    "content": "import asyncio\n\nimport pytest\nfrom exo_pyo3_bindings import (\n    Keypair,\n    NetworkingHandle,\n    NoPeersSubscribedToTopicError,\n    PyFromSwarm,\n)\n\n\n@pytest.mark.asyncio\nasync def test_sleep_on_multiple_items() -> None:\n    print(\"PYTHON: starting handle\")\n    h = NetworkingHandle(Keypair.generate())\n\n    rt = asyncio.create_task(_await_recv(h))\n\n    # sleep for 4 ticks\n    for i in range(4):\n        await asyncio.sleep(1)\n\n        try:\n            await h.gossipsub_publish(\"topic\", b\"somehting or other\")\n        except NoPeersSubscribedToTopicError as e:\n            print(\"caught it\", e)\n\n\nasync def _await_recv(h: NetworkingHandle):\n    while True:\n        event = await h.recv()\n        match event:\n            case PyFromSwarm.Connection() as c:\n                print(f\"PYTHON: connection update: {c}\")\n            case PyFromSwarm.Message() as m:\n                print(f\"PYTHON: message: {m}\")\n"
  },
  {
    "path": "rust/networking/Cargo.toml",
    "content": "[package]\nname = \"networking\"\nversion = { workspace = true }\nedition = { workspace = true }\npublish = false\n\n[lib]\ndoctest = false\nname = \"networking\"\npath = \"src/lib.rs\"\n\n[lints]\nworkspace = true\n\n[dependencies]\n# datastructures\neither = { workspace = true }\n\n# macro dependencies\nextend = { workspace = true }\ndelegate = { workspace = true }\n\n# async\nasync-stream = { workspace = true }\nfutures-lite = { workspace = true }\nfutures-timer = { workspace = true }\ntokio = { workspace = true, features = [\"full\"] }\n\n# utility dependencies\nutil = { workspace = true }\ntracing-subscriber = { version = \"0.3.19\", features = [\n  \"default\",\n  \"env-filter\",\n] }\nkeccak-const = { workspace = true }\n\n# tracing/logging\nlog = { workspace = true }\n\n# networking\nlibp2p = { workspace = true, features = [\"full\"] }\npin-project = \"1.1.10\"\n"
  },
  {
    "path": "rust/networking/examples/chatroom.rs",
    "content": "use futures_lite::StreamExt;\nuse libp2p::identity;\nuse networking::swarm;\nuse networking::swarm::{FromSwarm, ToSwarm};\nuse tokio::sync::{mpsc, oneshot};\nuse tokio::{io, io::AsyncBufReadExt as _};\nuse tracing_subscriber::EnvFilter;\nuse tracing_subscriber::filter::LevelFilter;\n\n#[tokio::main]\nasync fn main() {\n    let _ = tracing_subscriber::fmt()\n        .with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::INFO.into()))\n        .try_init();\n\n    let (to_swarm, from_client) = mpsc::channel(20);\n\n    // Configure swarm\n    let mut swarm = swarm::create_swarm(identity::Keypair::generate_ed25519(), from_client)\n        .expect(\"Swarm creation failed\")\n        .into_stream();\n\n    // Create a Gossipsub topic & subscribe\n    let (tx, rx) = oneshot::channel();\n    _ = to_swarm\n        .send(ToSwarm::Subscribe {\n            topic: \"test-net\".to_string(),\n            result_sender: tx,\n        })\n        .await\n        .expect(\"should send\");\n\n    // Read full lines from stdin\n    let mut stdin = io::BufReader::new(io::stdin()).lines();\n    println!(\"Enter messages via STDIN and they will be sent to connected peers using Gossipsub\");\n\n    tokio::task::spawn(async move {\n        rx.await\n            .expect(\"tx not dropped\")\n            .expect(\"subscribe shouldn't fail\");\n        loop {\n            if let Ok(Some(line)) = stdin.next_line().await {\n                let (tx, rx) = oneshot::channel();\n                if let Err(e) = to_swarm\n                    .send(swarm::ToSwarm::Publish {\n                        topic: \"test-net\".to_string(),\n                        data: line.as_bytes().to_vec(),\n                        result_sender: tx,\n                    })\n                    .await\n                {\n                    println!(\"Send error: {e:?}\");\n                    return;\n                };\n                match rx.await {\n                    Ok(Err(e)) => println!(\"Publish error: {e:?}\"),\n                    Err(e) => println!(\"Publish error: {e:?}\"),\n                    Ok(_) => {}\n                }\n            }\n        }\n    });\n\n    // Kick it off\n    loop {\n        // on gossipsub outgoing\n        match swarm.next().await {\n            // on gossipsub incoming\n            Some(FromSwarm::Discovered { peer_id }) => {\n                println!(\"\\n\\nconnected to {peer_id}\\n\\n\")\n            }\n            Some(FromSwarm::Expired { peer_id }) => {\n                println!(\"\\n\\ndisconnected from {peer_id}\\n\\n\")\n            }\n            Some(FromSwarm::Message { from, topic, data }) => {\n                println!(\"{topic}/{from}:\\n{}\", String::from_utf8_lossy(&data))\n            }\n            None => {}\n        }\n    }\n}\n"
  },
  {
    "path": "rust/networking/src/RESEARCH_NOTES.txt",
    "content": "https://github.com/ml-explore/mlx/commit/3fe98bacc7640d857acf3539f1d21b47a32e5609\n^raw sockets distributed -> `<net/ndrv.h>` -> https://newosxbook.com/code/xnu-3247.1.106/bsd/net/ndrv.h.auto.html\n--> header file for a networking component found in the macOS kernel (XNU) that defines structures for network device driver registration, specifically the ndrv_demux_desc and ndrv_protocol_desc structures used for demultiplexing protocol data at the network interface level. It specifies how to describe protocol data, such as an Ethernet type or a SNAP header, and how to associate these descriptions with a specific protocol family to receive matching packets.\n--> Used to bind an NDRV socket so that packets that match given protocol demux descriptions can be received.\n--> An NDRV socket is a special kind of socket in the Darwin/macOS operating system's XNU kernel, used for low-level network packet manipulation and binding to specific protocols for packet processing. It allows user-space applications or drivers to directly write Layer 2 (L2) network packets or interact with the network stack at a lower level, often by binding to protocol descriptors like the ndrv_protocol_desc. This type of socket is used for functions such as capturing and injecting packets, especially in network infrastructure software like routers or for kernel-level network monitoring and security tools.\n--> also called PF_NDRV sockets --> https://newosxbook.com/bonus/vol1ch16.html\n----> they are conceptually similar to https://scapy.disruptivelabs.in/networking/socket-interface PF_RAW or PF_PACKET\n\nhttps://stackoverflow.com/questions/17169298/af-packet-on-osx\n^AF_PACKET duplicates the packets as soon as it receives them from the physical layer (for incoming packets) or just before sending them out to the physical layer (for outgoing packets). -> this is on Linux only\n^it doesn't exist on OS X so you can use /dev/bpfX (Berkeley Packet Filter) for sniffing\n\nhttps://www.unix.com/man_page/mojave/4/ip/\n^OS X manpages for IP\n\nhttps://developer.apple.com/documentation/kernel/implementing_drivers_system_extensions_and_kexts\n^driver kit, system extensions & kexts for macOS\n\n----\n\nTo set up a Linux system to use a Thunderbolt connection as a network device, connect the two computers with a Thunderbolt cable, load the thunderbolt-net kernel module (usually automatic but modprobe is an option for manual loading), and then the operating system will create virtual Ethernet interfaces (e.g., thunderbolt0) for networking. 
You can then use standard tools like ifconfig or your desktop environment's network manager to configure these new interfaces for a link-local network.\n--> https://gist.github.com/geosp/80fbd39e617b7d1d9421683df4ea224a\n----> here is a guide on how to set up thunderbolt-ethernet on linux\n----> I may be able to steal the thunderbolt-net code ideas to implement a kernel module for MacOS\n\nhttps://chatgpt.com/s/t_68af8e41a8548191993281a014f846a7\n^GPT discussion about making socket interface\n\nhttps://chatgpt.com/s/t_68afb798a85c8191973c02a0fa7a48a3 --> link-local address,,??\nhttps://chatgpt.com/s/t_68afb02987e08191b2b0044d3667ece2\n^GPT discussion about accessing TB on MacOS low level interactions\n\n--------------------------------\n\nhttps://www.intel.com/content/www/us/en/support/articles/000098893/software.html\n^Thunderbolt Share & Thunderbolt Networking Mode => intel's equivalent of thunderbolt bridge\n\n\n---------------------------------\n\nhttps://www.zerotier.com/blog/how-zerotier-eliminated-kernel-extensions-on-macos/\n-->fake ethernet devices on MacOS -> omg??? we can detect thunderbolt bridge, then bind to it, then re-expose it as fake ethernet??\n-->ps: https://chatgpt.com/s/t_68afb2b25fb881919526763fb5d7359c, AF/PF_NDRV are one and the same!!!\n-->https://github.com/zerotier/ZeroTierOne/blob/dev/osdep/MacEthernetTapAgent.c"
  },
  {
    "path": "rust/networking/src/discovery.rs",
    "content": "use crate::ext::MultiaddrExt;\nuse delegate::delegate;\nuse either::Either;\nuse futures_lite::FutureExt;\nuse futures_timer::Delay;\nuse libp2p::core::transport::PortUse;\nuse libp2p::core::{ConnectedPoint, Endpoint};\nuse libp2p::swarm::behaviour::ConnectionEstablished;\nuse libp2p::swarm::dial_opts::DialOpts;\nuse libp2p::swarm::{\n    CloseConnection, ConnectionClosed, ConnectionDenied, ConnectionHandler,\n    ConnectionHandlerSelect, ConnectionId, FromSwarm, NetworkBehaviour, THandler, THandlerInEvent,\n    THandlerOutEvent, ToSwarm, dummy,\n};\nuse libp2p::{Multiaddr, PeerId, identity, mdns};\nuse std::collections::{BTreeSet, HashMap};\nuse std::convert::Infallible;\nuse std::io;\nuse std::net::IpAddr;\nuse std::task::{Context, Poll};\nuse std::time::Duration;\nuse util::wakerdeque::WakerDeque;\n\nconst RETRY_CONNECT_INTERVAL: Duration = Duration::from_secs(5);\n\nmod managed {\n    use libp2p::swarm::NetworkBehaviour;\n    use libp2p::{identity, mdns, ping};\n    use std::io;\n    use std::time::Duration;\n\n    const MDNS_RECORD_TTL: Duration = Duration::from_secs(2_500);\n    const MDNS_QUERY_INTERVAL: Duration = Duration::from_secs(1_500);\n    const PING_TIMEOUT: Duration = Duration::from_millis(2_500);\n    const PING_INTERVAL: Duration = Duration::from_millis(2_500);\n\n    #[derive(NetworkBehaviour)]\n    pub struct Behaviour {\n        mdns: mdns::tokio::Behaviour,\n        ping: ping::Behaviour,\n    }\n\n    impl Behaviour {\n        pub fn new(keypair: &identity::Keypair) -> io::Result<Self> {\n            Ok(Self {\n                mdns: mdns_behaviour(keypair)?,\n                ping: ping_behaviour(),\n            })\n        }\n    }\n\n    fn mdns_behaviour(keypair: &identity::Keypair) -> io::Result<mdns::tokio::Behaviour> {\n        use mdns::{Config, tokio};\n\n        // mDNS config => enable IPv6\n        let mdns_config = Config {\n            ttl: MDNS_RECORD_TTL,\n            query_interval: MDNS_QUERY_INTERVAL,\n\n            // enable_ipv6: true, // TODO: for some reason, TCP+mDNS don't work well with ipv6?? 
figure out how to make work\n            ..Default::default()\n        };\n\n        let mdns_behaviour = tokio::Behaviour::new(mdns_config, keypair.public().to_peer_id());\n        Ok(mdns_behaviour?)\n    }\n\n    fn ping_behaviour() -> ping::Behaviour {\n        ping::Behaviour::new(\n            ping::Config::new()\n                .with_timeout(PING_TIMEOUT)\n                .with_interval(PING_INTERVAL),\n        )\n    }\n}\n\n/// Events for when a listening connection is truly established and truly closed.\n#[derive(Debug, Clone)]\npub enum Event {\n    ConnectionEstablished {\n        peer_id: PeerId,\n        connection_id: ConnectionId,\n        remote_ip: IpAddr,\n        remote_tcp_port: u16,\n    },\n    ConnectionClosed {\n        peer_id: PeerId,\n        connection_id: ConnectionId,\n        remote_ip: IpAddr,\n        remote_tcp_port: u16,\n    },\n}\n\n/// Discovery behavior that wraps mDNS to produce truly discovered durable peer-connections.\n///\n/// The behaviour operates as such:\n///  1) All true (listening) connections/disconnections are tracked, emitting corresponding events\n///     to the swarm.\n///  1) mDNS discovered/expired peers are tracked; discovered but not connected peers are dialed\n///     immediately, and expired but connected peers are disconnected from immediately.\n///  2) Every fixed interval: discovered but not connected peers are dialed, and expired but\n///     connected peers are disconnected from.\npub struct Behaviour {\n    // state-tracking for managed behaviors & mDNS-discovered peers\n    managed: managed::Behaviour,\n    mdns_discovered: HashMap<PeerId, BTreeSet<Multiaddr>>,\n\n    retry_delay: Delay, // retry interval\n\n    // pending events to emmit => waker-backed Deque to control polling\n    pending_events: WakerDeque<ToSwarm<Event, Infallible>>,\n}\n\nimpl Behaviour {\n    pub fn new(keypair: &identity::Keypair) -> io::Result<Self> {\n        Ok(Self {\n            managed: managed::Behaviour::new(keypair)?,\n            mdns_discovered: HashMap::new(),\n            retry_delay: Delay::new(RETRY_CONNECT_INTERVAL),\n            pending_events: WakerDeque::new(),\n        })\n    }\n\n    fn dial(&mut self, peer_id: PeerId, addr: Multiaddr) {\n        self.pending_events.push_back(ToSwarm::Dial {\n            opts: DialOpts::peer_id(peer_id).addresses(vec![addr]).build(),\n        })\n    }\n\n    fn close_connection(&mut self, peer_id: PeerId, connection: ConnectionId) {\n        // push front to make this IMMEDIATE\n        self.pending_events.push_front(ToSwarm::CloseConnection {\n            peer_id,\n            connection: CloseConnection::One(connection),\n        })\n    }\n\n    fn handle_mdns_discovered(&mut self, peers: Vec<(PeerId, Multiaddr)>) {\n        for (p, ma) in peers {\n            self.dial(p, ma.clone()); // always connect\n\n            // get peer's multi-addresses or insert if missing\n            let Some(mas) = self.mdns_discovered.get_mut(&p) else {\n                self.mdns_discovered.insert(p, BTreeSet::from([ma]));\n                continue;\n            };\n\n            // multiaddress should never already be present - else something has gone wrong\n            let is_new_addr = mas.insert(ma);\n            assert!(is_new_addr, \"cannot discover a discovered peer\");\n        }\n    }\n\n    fn handle_mdns_expired(&mut self, peers: Vec<(PeerId, Multiaddr)>) {\n        for (p, ma) in peers {\n            // at this point, we *must* have the peer\n            let mas = self\n                
.mdns_discovered\n                .get_mut(&p)\n                .expect(\"nonexistent peer cannot expire\");\n\n            // at this point, we *must* have the multiaddress\n            let was_present = mas.remove(&ma);\n            assert!(was_present, \"nonexistent multiaddress cannot expire\");\n\n            // if empty, remove the peer-id entirely\n            if mas.is_empty() {\n                self.mdns_discovered.remove(&p);\n            }\n        }\n    }\n\n    fn on_connection_established(\n        &mut self,\n        peer_id: PeerId,\n        connection_id: ConnectionId,\n        remote_ip: IpAddr,\n        remote_tcp_port: u16,\n    ) {\n        // send out connected event\n        self.pending_events\n            .push_back(ToSwarm::GenerateEvent(Event::ConnectionEstablished {\n                peer_id,\n                connection_id,\n                remote_ip,\n                remote_tcp_port,\n            }));\n    }\n\n    fn on_connection_closed(\n        &mut self,\n        peer_id: PeerId,\n        connection_id: ConnectionId,\n        remote_ip: IpAddr,\n        remote_tcp_port: u16,\n    ) {\n        // send out disconnected event\n        self.pending_events\n            .push_back(ToSwarm::GenerateEvent(Event::ConnectionClosed {\n                peer_id,\n                connection_id,\n                remote_ip,\n                remote_tcp_port,\n            }));\n    }\n}\n\nimpl NetworkBehaviour for Behaviour {\n    type ConnectionHandler =\n        ConnectionHandlerSelect<dummy::ConnectionHandler, THandler<managed::Behaviour>>;\n    type ToSwarm = Event;\n\n    // simply delegate to underlying mDNS behaviour\n\n    delegate! {\n        to self.managed {\n            fn handle_pending_inbound_connection(&mut self, connection_id: ConnectionId, local_addr: &Multiaddr, remote_addr: &Multiaddr) -> Result<(), ConnectionDenied>;\n            fn handle_pending_outbound_connection(&mut self, connection_id: ConnectionId, maybe_peer: Option<PeerId>, addresses: &[Multiaddr], effective_role: Endpoint) -> Result<Vec<Multiaddr>, ConnectionDenied>;\n        }\n    }\n\n    fn handle_established_inbound_connection(\n        &mut self,\n        connection_id: ConnectionId,\n        peer: PeerId,\n        local_addr: &Multiaddr,\n        remote_addr: &Multiaddr,\n    ) -> Result<THandler<Self>, ConnectionDenied> {\n        Ok(ConnectionHandler::select(\n            dummy::ConnectionHandler,\n            self.managed.handle_established_inbound_connection(\n                connection_id,\n                peer,\n                local_addr,\n                remote_addr,\n            )?,\n        ))\n    }\n\n    #[allow(clippy::needless_question_mark)]\n    fn handle_established_outbound_connection(\n        &mut self,\n        connection_id: ConnectionId,\n        peer: PeerId,\n        addr: &Multiaddr,\n        role_override: Endpoint,\n        port_use: PortUse,\n    ) -> Result<THandler<Self>, ConnectionDenied> {\n        Ok(ConnectionHandler::select(\n            dummy::ConnectionHandler,\n            self.managed.handle_established_outbound_connection(\n                connection_id,\n                peer,\n                addr,\n                role_override,\n                port_use,\n            )?,\n        ))\n    }\n\n    fn on_connection_handler_event(\n        &mut self,\n        peer_id: PeerId,\n        connection_id: ConnectionId,\n        event: THandlerOutEvent<Self>,\n    ) {\n        match event {\n            Either::Left(ev) => 
libp2p::core::util::unreachable(ev),\n            Either::Right(ev) => {\n                self.managed\n                    .on_connection_handler_event(peer_id, connection_id, ev)\n            }\n        }\n    }\n\n    // hook into these methods to drive behavior\n\n    fn on_swarm_event(&mut self, event: FromSwarm) {\n        self.managed.on_swarm_event(event); // let mDNS handle swarm events\n\n        // handle swarm events to update internal state:\n        match event {\n            FromSwarm::ConnectionEstablished(ConnectionEstablished {\n                peer_id,\n                connection_id,\n                endpoint,\n                ..\n            }) => {\n                let remote_address = match endpoint {\n                    ConnectedPoint::Dialer { address, .. } => address,\n                    ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr,\n                };\n\n                if let Some((ip, port)) = remote_address.try_to_tcp_addr() {\n                    // handle connection established event which is filtered correctly\n                    self.on_connection_established(peer_id, connection_id, ip, port)\n                }\n            }\n            FromSwarm::ConnectionClosed(ConnectionClosed {\n                peer_id,\n                connection_id,\n                endpoint,\n                ..\n            }) => {\n                let remote_address = match endpoint {\n                    ConnectedPoint::Dialer { address, .. } => address,\n                    ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr,\n                };\n\n                if let Some((ip, port)) = remote_address.try_to_tcp_addr() {\n                    // handle connection closed event which is filtered correctly\n                    self.on_connection_closed(peer_id, connection_id, ip, port)\n                }\n            }\n\n            // since we are running TCP/IP transport layer, we are assuming that\n            // no address changes can occur, hence encountering one is a fatal error\n            FromSwarm::AddressChange(a) => {\n                unreachable!(\"unhandlable: address change encountered: {:?}\", a)\n            }\n            _ => {}\n        }\n    }\n\n    fn poll(&mut self, cx: &mut Context) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {\n        // delegate to managed behaviors for any behaviors they need to perform\n        match self.managed.poll(cx) {\n            Poll::Ready(ToSwarm::GenerateEvent(e)) => {\n                match e {\n                    // handle discovered and expired events from mDNS\n                    managed::BehaviourEvent::Mdns(e) => match e.clone() {\n                        mdns::Event::Discovered(peers) => {\n                            self.handle_mdns_discovered(peers);\n                        }\n                        mdns::Event::Expired(peers) => {\n                            self.handle_mdns_expired(peers);\n                        }\n                    },\n\n                    // handle ping events => if error then disconnect\n                    managed::BehaviourEvent::Ping(e) => {\n                        if let Err(_) = e.result {\n                            self.close_connection(e.peer, e.connection.clone())\n                        }\n                    }\n                }\n\n                // since we just consumed an event, we should immediately wake just in case\n                // there are more events to come where that came from\n                
cx.waker().wake_by_ref();\n            }\n\n            // forward any other mDNS event to the swarm or its connection handler(s)\n            Poll::Ready(e) => {\n                return Poll::Ready(\n                    e.map_out(|_| unreachable!(\"events returning to swarm already handled\"))\n                        .map_in(Either::Right),\n                );\n            }\n\n            Poll::Pending => {}\n        }\n\n        // retry connecting to all mDNS peers periodically (fails safely if already connected)\n        if self.retry_delay.poll(cx).is_ready() {\n            for (p, mas) in self.mdns_discovered.clone() {\n                for ma in mas {\n                    self.dial(p, ma)\n                }\n            }\n            self.retry_delay.reset(RETRY_CONNECT_INTERVAL) // reset timeout\n        }\n\n        // send out any pending events from our own service\n        if let Some(e) = self.pending_events.pop_front(cx) {\n            return Poll::Ready(e.map_in(Either::Left));\n        }\n\n        // wait for pending events\n        Poll::Pending\n    }\n}\n"
  },
  {
    "path": "rust/networking/src/lib.rs",
    "content": "//! TODO: crate documentation\n//!\n//! this is here as a placeholder documentation\n//!\n//!\npub mod discovery;\npub mod swarm;\n\n/// Namespace for all the type/trait aliases used by this crate.\npub(crate) mod alias {\n    use std::error::Error;\n\n    pub type AnyError = Box<dyn Error + Send + Sync + 'static>;\n    pub type AnyResult<T> = Result<T, AnyError>;\n}\n\n/// Namespace for crate-wide extension traits/methods\npub(crate) mod ext {\n    use extend::ext;\n    use libp2p::Multiaddr;\n    use libp2p::multiaddr::Protocol;\n    use std::net::IpAddr;\n\n    #[ext(pub, name = MultiaddrExt)]\n    impl Multiaddr {\n        /// If the multiaddress corresponds to a TCP address, extracts it\n        fn try_to_tcp_addr(&self) -> Option<(IpAddr, u16)> {\n            let mut ps = self.into_iter();\n            let ip = if let Some(p) = ps.next() {\n                match p {\n                    Protocol::Ip4(ip) => IpAddr::V4(ip),\n                    Protocol::Ip6(ip) => IpAddr::V6(ip),\n                    _ => return None,\n                }\n            } else {\n                return None;\n            };\n            let Some(Protocol::Tcp(port)) = ps.next() else {\n                return None;\n            };\n            Some((ip, port))\n        }\n    }\n}\n"
  },
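`try_to_tcp_addr` above only accepts multiaddresses of the shape `/ip4|ip6/<addr>/tcp/<port>` and ignores any trailing protocols. A rough Python mirror of the same parse, for illustration only (not repo code):

```python
from ipaddress import ip_address


def try_to_tcp_addr(multiaddr: str) -> tuple[str, int] | None:
    """Python mirror of MultiaddrExt::try_to_tcp_addr: IP protocol first, then TCP."""
    parts = multiaddr.strip("/").split("/")
    if len(parts) < 4 or parts[0] not in ("ip4", "ip6") or parts[2] != "tcp":
        return None
    return str(ip_address(parts[1])), int(parts[3])


assert try_to_tcp_addr("/ip4/127.0.0.1/tcp/4001") == ("127.0.0.1", 4001)
assert try_to_tcp_addr("/ip4/127.0.0.1/udp/4001") is None
```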
  {
    "path": "rust/networking/src/swarm.rs",
    "content": "use std::pin::Pin;\n\nuse crate::swarm::transport::tcp_transport;\nuse crate::{alias, discovery};\npub use behaviour::{Behaviour, BehaviourEvent};\nuse futures_lite::{Stream, StreamExt};\nuse libp2p::{PeerId, SwarmBuilder, gossipsub, identity, swarm::SwarmEvent};\nuse tokio::sync::{mpsc, oneshot};\n\n/// The current version of the network: this prevents devices running different versions of the\n/// software from interacting with each other.\n///\n/// TODO: right now this is a hardcoded constant; figure out what the versioning semantics should\n///       even be, and how to inject the right version into this config/initialization. E.g. should\n///       this be passed in as a parameter? What about rapidly changing versions in debug builds?\n///       this is all VERY very hard to figure out and needs to be mulled over as a team.\npub const NETWORK_VERSION: &[u8] = b\"v0.0.1\";\npub const OVERRIDE_VERSION_ENV_VAR: &str = \"EXO_LIBP2P_NAMESPACE\";\n\n// Uses oneshot senders to emulate function calling apis while avoiding requiring unique ownership\n// of the Swarm.\npub enum ToSwarm {\n    Unsubscribe {\n        topic: String,\n        result_sender: oneshot::Sender<bool>,\n    },\n    Subscribe {\n        topic: String,\n        result_sender: oneshot::Sender<Result<bool, gossipsub::SubscriptionError>>,\n    },\n    Publish {\n        topic: String,\n        data: Vec<u8>,\n        result_sender: oneshot::Sender<Result<gossipsub::MessageId, gossipsub::PublishError>>,\n    },\n}\npub enum FromSwarm {\n    Message {\n        from: PeerId,\n        topic: String,\n        data: Vec<u8>,\n    },\n    Discovered {\n        peer_id: PeerId,\n    },\n    Expired {\n        peer_id: PeerId,\n    },\n}\n\npub struct Swarm {\n    swarm: libp2p::Swarm<Behaviour>,\n    from_client: mpsc::Receiver<ToSwarm>,\n}\n\nimpl Swarm {\n    pub fn into_stream(self) -> Pin<Box<dyn Stream<Item = FromSwarm> + Send>> {\n        let Swarm {\n            mut swarm,\n            mut from_client,\n        } = self;\n        let stream = async_stream::stream! {\n            loop {\n                tokio::select! 
{\n                    msg = from_client.recv() => {\n                        let Some(msg) = msg else { break };\n                        on_message(&mut swarm, msg);\n                    }\n                    event = swarm.next() => {\n                        let Some(event) = event else { break };\n                        if let Some(item) = filter_swarm_event(event) {\n                            yield item;\n                        }\n                    }\n                }\n            }\n        };\n        Box::pin(stream)\n    }\n}\n\nfn on_message(swarm: &mut libp2p::Swarm<Behaviour>, message: ToSwarm) {\n    match message {\n        ToSwarm::Subscribe {\n            topic,\n            result_sender,\n        } => {\n            let result = swarm\n                .behaviour_mut()\n                .gossipsub\n                .subscribe(&gossipsub::IdentTopic::new(topic));\n            _ = result_sender.send(result);\n        }\n        ToSwarm::Unsubscribe {\n            topic,\n            result_sender,\n        } => {\n            let result = swarm\n                .behaviour_mut()\n                .gossipsub\n                .unsubscribe(&gossipsub::IdentTopic::new(topic));\n            _ = result_sender.send(result);\n        }\n        ToSwarm::Publish {\n            topic,\n            data,\n            result_sender,\n        } => {\n            let result = swarm\n                .behaviour_mut()\n                .gossipsub\n                .publish(gossipsub::IdentTopic::new(topic), data);\n            _ = result_sender.send(result);\n        }\n    }\n}\n\nfn filter_swarm_event(event: SwarmEvent<BehaviourEvent>) -> Option<FromSwarm> {\n    match event {\n        SwarmEvent::Behaviour(BehaviourEvent::Gossipsub(gossipsub::Event::Message {\n            message:\n                gossipsub::Message {\n                    source: Some(peer_id),\n                    topic,\n                    data,\n                    ..\n                },\n            ..\n        })) => Some(FromSwarm::Message {\n            from: peer_id,\n            topic: topic.into_string(),\n            data,\n        }),\n        SwarmEvent::Behaviour(BehaviourEvent::Discovery(\n            discovery::Event::ConnectionEstablished { peer_id, .. 
},\n        )) => Some(FromSwarm::Discovered { peer_id }),\n        SwarmEvent::Behaviour(BehaviourEvent::Discovery(discovery::Event::ConnectionClosed {\n            peer_id,\n            ..\n        })) => Some(FromSwarm::Expired { peer_id }),\n        _ => None,\n    }\n}\n\n/// Create and configure a swarm which listens to all ports on OS\npub fn create_swarm(\n    keypair: identity::Keypair,\n    from_client: mpsc::Receiver<ToSwarm>,\n) -> alias::AnyResult<Swarm> {\n    let mut swarm = SwarmBuilder::with_existing_identity(keypair)\n        .with_tokio()\n        .with_other_transport(tcp_transport)?\n        .with_behaviour(Behaviour::new)?\n        .build();\n\n    // Listen on all interfaces and whatever port the OS assigns\n    swarm.listen_on(\"/ip4/0.0.0.0/tcp/0\".parse()?)?;\n    Ok(Swarm { swarm, from_client })\n}\n\nmod transport {\n    use crate::alias;\n    use crate::swarm::{NETWORK_VERSION, OVERRIDE_VERSION_ENV_VAR};\n    use futures_lite::{AsyncRead, AsyncWrite};\n    use keccak_const::Sha3_256;\n    use libp2p::core::muxing;\n    use libp2p::core::transport::Boxed;\n    use libp2p::pnet::{PnetError, PnetOutput};\n    use libp2p::{PeerId, Transport, identity, noise, pnet, yamux};\n    use std::{env, sync::LazyLock};\n\n    /// Key used for networking's private network; parametrized on the [`NETWORK_VERSION`].\n    /// See [`pnet_upgrade`] for more.\n    static PNET_PRESHARED_KEY: LazyLock<[u8; 32]> = LazyLock::new(|| {\n        let builder = Sha3_256::new().update(b\"exo_discovery_network\");\n\n        if let Ok(var) = env::var(OVERRIDE_VERSION_ENV_VAR) {\n            let bytes = var.into_bytes();\n            builder.update(&bytes)\n        } else {\n            builder.update(NETWORK_VERSION)\n        }\n        .finalize()\n    });\n\n    /// Make the Swarm run on a private network, as to not clash with public libp2p nodes and\n    /// also different-versioned instances of this same network.\n    /// This is implemented as an additional \"upgrade\" ontop of existing [`libp2p::Transport`] layers.\n    async fn pnet_upgrade<TSocket>(\n        socket: TSocket,\n        _: impl Sized,\n    ) -> Result<PnetOutput<TSocket>, PnetError>\n    where\n        TSocket: AsyncRead + AsyncWrite + Send + Unpin + 'static,\n    {\n        use pnet::{PnetConfig, PreSharedKey};\n        PnetConfig::new(PreSharedKey::new(*PNET_PRESHARED_KEY))\n            .handshake(socket)\n            .await\n    }\n\n    /// TCP/IP transport layer configuration.\n    pub fn tcp_transport(\n        keypair: &identity::Keypair,\n    ) -> alias::AnyResult<Boxed<(PeerId, muxing::StreamMuxerBox)>> {\n        use libp2p::{\n            core::upgrade::Version,\n            tcp::{Config, tokio},\n        };\n\n        // `TCP_NODELAY` enabled => avoid latency\n        let tcp_config = Config::default().nodelay(true);\n\n        // V1 + lazy flushing => 0-RTT negotiation\n        let upgrade_version = Version::V1Lazy;\n\n        // Noise is faster than TLS + we don't care much for security\n        let noise_config = noise::Config::new(keypair)?;\n\n        // Use default Yamux config for multiplexing\n        let yamux_config = yamux::Config::default();\n\n        // Create new Tokio-driven TCP/IP transport layer\n        let base_transport = tokio::Transport::new(tcp_config)\n            .and_then(pnet_upgrade)\n            .upgrade(upgrade_version)\n            .authenticate(noise_config)\n            .multiplex(yamux_config);\n\n        // Return boxed transport (to flatten complex type)\n        
Ok(base_transport.boxed())\n    }\n}\n\nmod behaviour {\n    use crate::{alias, discovery};\n    use libp2p::swarm::NetworkBehaviour;\n    use libp2p::{gossipsub, identity};\n\n    /// Behavior of the Swarm which composes all desired behaviors:\n    /// Right now it's just [`discovery::Behaviour`] and [`gossipsub::Behaviour`].\n    #[derive(NetworkBehaviour)]\n    pub struct Behaviour {\n        pub discovery: discovery::Behaviour,\n        pub gossipsub: gossipsub::Behaviour,\n    }\n\n    impl Behaviour {\n        pub fn new(keypair: &identity::Keypair) -> alias::AnyResult<Self> {\n            Ok(Self {\n                discovery: discovery::Behaviour::new(keypair)?,\n                gossipsub: gossipsub_behaviour(keypair),\n            })\n        }\n    }\n\n    fn gossipsub_behaviour(keypair: &identity::Keypair) -> gossipsub::Behaviour {\n        use gossipsub::{ConfigBuilder, MessageAuthenticity, ValidationMode};\n\n        // build a gossipsub network behaviour\n        //  => signed message authenticity + strict validation mode means the message-ID is\n        //     automatically provided by gossipsub without needing to provide a custom message-ID function\n        gossipsub::Behaviour::new(\n            MessageAuthenticity::Signed(keypair.clone()),\n            ConfigBuilder::default()\n                .max_transmit_size(8 * 1024 * 1024)\n                .validation_mode(ValidationMode::Strict)\n                .build()\n                .expect(\"the configuration should always be valid\"),\n        )\n        .expect(\"creating gossipsub behavior should always work\")\n    }\n}\n"
  },
  {
    "path": "rust/networking/tests/dummy.rs",
    "content": "// maybe this will hold test in the future...??\n\n#[cfg(test)]\nmod tests {\n    #[test]\n    fn does_nothing() {}\n}\n"
  },
  {
    "path": "rust/parts.nix",
    "content": "{ inputs, ... }:\n{\n  perSystem =\n    { inputs', pkgs, lib, ... }:\n    let\n      # Fenix nightly toolchain with all components\n      rustToolchain = inputs'.fenix.packages.stable.withComponents [\n        \"cargo\"\n        \"rustc\"\n        \"clippy\"\n        \"rustfmt\"\n        \"rust-src\"\n        \"rust-analyzer\"\n      ];\n\n      # Crane with fenix toolchain\n      craneLib = (inputs.crane.mkLib pkgs).overrideToolchain rustToolchain;\n\n      # Source filtering - only include rust/ directory and root Cargo files\n      # This ensures changes to Python/docs/etc don't trigger Rust rebuilds\n      src = lib.cleanSourceWith {\n        src = inputs.self;\n        filter =\n          path: type:\n          let\n            baseName = builtins.baseNameOf path;\n            parentDir = builtins.dirOf path;\n            inRustDir =\n              (lib.hasInfix \"/rust/\" path)\n              || (lib.hasSuffix \"/rust\" parentDir)\n              || (baseName == \"rust\" && type == \"directory\");\n            isRootCargoFile =\n              (baseName == \"Cargo.toml\" || baseName == \"Cargo.lock\")\n              && (builtins.dirOf path == toString inputs.self);\n          in\n          isRootCargoFile\n          || (inRustDir && (craneLib.filterCargoSources path type || lib.hasSuffix \".toml\" path || lib.hasSuffix \".md\" path));\n      };\n\n      # Common arguments for all Rust builds\n      commonArgs = {\n        inherit src;\n        pname = \"exo-rust\";\n        version = \"0.0.1\";\n        strictDeps = true;\n\n        nativeBuildInputs = [\n          pkgs.pkg-config\n          pkgs.python313 # Required for pyo3-build-config\n        ];\n\n        buildInputs = [\n          pkgs.openssl\n          pkgs.python313 # Required for pyo3 tests\n        ];\n\n        OPENSSL_NO_VENDOR = \"1\";\n\n        # Required for pyo3 tests to find libpython\n        LD_LIBRARY_PATH = lib.makeLibraryPath [ pkgs.python313 ];\n      };\n\n      # Build dependencies once for caching\n      cargoArtifacts = craneLib.buildDepsOnly (\n        commonArgs\n        // {\n          cargoExtraArgs = \"--workspace\";\n        }\n      );\n    in\n    {\n      # Export toolchain for use in treefmt and devShell\n      options.rust = {\n        toolchain = lib.mkOption {\n          type = lib.types.package;\n          default = rustToolchain;\n          description = \"The Rust toolchain to use\";\n        };\n      };\n\n      config = {\n        packages = {\n          # Python bindings wheel via maturin\n          exo_pyo3_bindings = craneLib.buildPackage (\n            commonArgs\n            // {\n              inherit cargoArtifacts;\n              pname = \"exo_pyo3_bindings\";\n\n              nativeBuildInputs = commonArgs.nativeBuildInputs ++ [\n                pkgs.maturin\n              ];\n\n              buildPhaseCargoCommand = ''\n                maturin build \\\n                  --release \\\n                  --manylinux off \\\n                  --manifest-path rust/exo_pyo3_bindings/Cargo.toml \\\n                  --features \"pyo3/extension-module,pyo3/experimental-async\" \\\n                  --interpreter ${pkgs.python313}/bin/python \\\n                  --out dist\n              '';\n\n              # Don't use crane's default install behavior\n              doNotPostBuildInstallCargoBinaries = true;\n\n              installPhaseCommand = ''\n                mkdir -p $out\n                cp dist/*.whl $out/\n              '';\n            }\n          );\n    
    };\n\n        checks = {\n          # Full workspace build (all crates)\n          cargo-build = craneLib.buildPackage (\n            commonArgs\n            // {\n              inherit cargoArtifacts;\n              cargoExtraArgs = \"--workspace\";\n            }\n          );\n          # Run tests with nextest\n          cargo-nextest = craneLib.cargoNextest (\n            commonArgs\n            // {\n              inherit cargoArtifacts;\n              cargoExtraArgs = \"--workspace\";\n            }\n          );\n\n          # Build documentation\n          cargo-doc = craneLib.cargoDoc (\n            commonArgs\n            // {\n              inherit cargoArtifacts;\n              cargoExtraArgs = \"--workspace\";\n            }\n          );\n        };\n      };\n    };\n}\n"
  },
  {
    "path": "rust/util/Cargo.toml",
    "content": "[package]\nname = \"util\"\nversion = { workspace = true }\nedition = { workspace = true }\npublish = false\n\n[lib]\ndoctest = false\nname = \"util\"\npath = \"src/lib.rs\"\n\n[lints]\nworkspace = true\n\n[dependencies]\n"
  },
  {
    "path": "rust/util/src/lib.rs",
    "content": "pub mod wakerdeque;\n"
  },
  {
    "path": "rust/util/src/wakerdeque.rs",
    "content": "use std::collections::VecDeque;\nuse std::fmt::{Debug, Formatter};\nuse std::task::{Context, Waker};\n\n/// A wrapper around [`VecDeque`] which wakes (if it can) on any `push_*` methods,\n/// and updates the internally stored waker by consuming [`Context`] on any `pop_*` methods.\npub struct WakerDeque<T> {\n    waker: Option<Waker>,\n    deque: VecDeque<T>,\n}\n\nimpl<T: Debug> Debug for WakerDeque<T> {\n    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {\n        self.deque.fmt(f)\n    }\n}\n\nimpl<T> WakerDeque<T> {\n    pub fn new() -> Self {\n        Self {\n            waker: None,\n            deque: VecDeque::new(),\n        }\n    }\n\n    fn update(&mut self, cx: &mut Context<'_>) {\n        self.waker = Some(cx.waker().clone());\n    }\n\n    fn wake(&mut self) {\n        let Some(ref mut w) = self.waker else { return };\n        w.wake_by_ref();\n        self.waker = None;\n    }\n\n    pub fn pop_front(&mut self, cx: &mut Context<'_>) -> Option<T> {\n        self.update(cx);\n        self.deque.pop_front()\n    }\n\n    pub fn pop_back(&mut self, cx: &mut Context<'_>) -> Option<T> {\n        self.update(cx);\n        self.deque.pop_back()\n    }\n\n    pub fn push_front(&mut self, value: T) {\n        self.wake();\n        self.deque.push_front(value);\n    }\n\n    pub fn push_back(&mut self, value: T) {\n        self.wake();\n        self.deque.push_back(value);\n    }\n}\n"
  },
  {
    "path": "scripts/fetch_kv_heads.py",
    "content": "#!/usr/bin/env python3\n\"\"\"Fetch num_key_value_heads from HuggingFace config.json and update TOML model cards.\n\nUsage:\n    # Update only cards missing num_key_value_heads\n    uv run python scripts/fetch_kv_heads.py --missing\n\n    # Update all cards (overwrite existing values)\n    uv run python scripts/fetch_kv_heads.py --all\n\"\"\"\n\nfrom __future__ import annotations\n\nimport argparse\nimport json\nimport sys\nimport urllib.request\nfrom concurrent.futures import ThreadPoolExecutor, as_completed\nfrom pathlib import Path\n\nimport tomlkit\n\nCARDS_DIR = (\n    Path(__file__).resolve().parent.parent / \"resources\" / \"inference_model_cards\"\n)\nMAX_WORKERS = 5\n\n\ndef fetch_kv_heads(model_id: str) -> int | None:\n    \"\"\"Fetch num_key_value_heads from HuggingFace config.json.\"\"\"\n    url = f\"https://huggingface.co/{model_id}/raw/main/config.json\"\n    try:\n        with urllib.request.urlopen(url, timeout=15) as resp:\n            config = json.loads(resp.read())\n    except Exception as e:\n        print(f\"  ERROR fetching {url}: {e}\", file=sys.stderr)\n        return None\n\n    for source in [config, config.get(\"text_config\", {})]:\n        if \"num_key_value_heads\" in source:\n            return int(source[\"num_key_value_heads\"])\n\n    return None\n\n\ndef update_toml(path: Path, kv_heads: int) -> bool:\n    \"\"\"Insert or update num_key_value_heads in a TOML file. Returns True if changed.\"\"\"\n    content = path.read_text()\n    doc = tomlkit.parse(content)\n\n    if doc.get(\"num_key_value_heads\") == kv_heads:\n        return False\n\n    # Insert after hidden_size if adding for the first time\n    if \"num_key_value_heads\" not in doc:\n        new_doc = tomlkit.document()\n        for key, value in doc.items():\n            new_doc[key] = value\n            if key == \"hidden_size\":\n                new_doc[\"num_key_value_heads\"] = kv_heads\n        path.write_text(tomlkit.dumps(new_doc))\n    else:\n        doc[\"num_key_value_heads\"] = kv_heads\n        path.write_text(tomlkit.dumps(doc))\n\n    return True\n\n\ndef process_card(path: Path) -> tuple[str, str]:\n    \"\"\"Fetch and update a single card. 
Returns (filename, status).\"\"\"\n    content = path.read_text()\n    doc = tomlkit.parse(content)\n    model_id = doc.get(\"model_id\")\n    if not model_id:\n        return path.name, \"SKIP (no model_id)\"\n\n    kv_heads = fetch_kv_heads(str(model_id))\n    if kv_heads is None:\n        return path.name, \"FAILED\"\n\n    changed = update_toml(path, kv_heads)\n    return path.name, f\"{kv_heads} ({'UPDATED' if changed else 'UNCHANGED'})\"\n\n\ndef main():\n    parser = argparse.ArgumentParser(\n        description=\"Fetch num_key_value_heads from HuggingFace and update TOML cards.\"\n    )\n    group = parser.add_mutually_exclusive_group(required=True)\n    group.add_argument(\n        \"--all\",\n        action=\"store_true\",\n        help=\"Update all model cards (overwrite existing values)\",\n    )\n    group.add_argument(\n        \"--missing\",\n        action=\"store_true\",\n        help=\"Only update cards missing num_key_value_heads\",\n    )\n    args = parser.parse_args()\n\n    toml_files = sorted(CARDS_DIR.glob(\"*.toml\"))\n    if not toml_files:\n        print(f\"No TOML files found in {CARDS_DIR}\", file=sys.stderr)\n        sys.exit(1)\n\n    to_process = []\n    skipped = 0\n\n    for path in toml_files:\n        if args.missing and \"num_key_value_heads\" in path.read_text():\n            skipped += 1\n            continue\n        to_process.append(path)\n\n    updated = 0\n    failed = 0\n\n    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as pool:\n        futures = {pool.submit(process_card, path): path for path in to_process}\n        for future in as_completed(futures):\n            name, status = future.result()\n            print(f\"  {name}: {status}\")\n            if \"UPDATED\" in status:\n                updated += 1\n            elif \"FAILED\" in status:\n                failed += 1\n\n    print(f\"\\nDone: {updated} updated, {skipped} skipped, {failed} failed\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/exo/__init__.py",
    "content": ""
  },
  {
    "path": "src/exo/__main__.py",
    "content": "from __future__ import annotations\n\nimport sys\nfrom collections.abc import Sequence\nfrom multiprocessing import freeze_support\nfrom typing import Final\n\nfrom exo.main import main\n\nINLINE_CODE_FLAG: Final[str] = \"-c\"\n\n\ndef _maybe_run_inline_code(argv: Sequence[str]) -> bool:\n    \"\"\"\n    Reproduce the bare minimum of Python's `-c` flag so multiprocessing\n    helper processes (for example the resource tracker) can execute.\n    \"\"\"\n\n    try:\n        flag_index = argv.index(INLINE_CODE_FLAG)\n    except ValueError:\n        return False\n\n    code_index = flag_index + 1\n    if code_index >= len(argv):\n        return False\n\n    inline_code = argv[code_index]\n    sys.argv = [\"-c\", *argv[code_index + 1 :]]\n    namespace: dict[str, object] = {\"__name__\": \"__main__\"}\n    exec(inline_code, namespace, namespace)\n    return True\n\n\nif __name__ == \"__main__\":\n    if _maybe_run_inline_code(sys.argv):\n        sys.exit(0)\n    freeze_support()\n    main()\n"
  },
  {
    "path": "src/exo/api/__init__.py",
    "content": ""
  },
  {
    "path": "src/exo/api/adapters/__init__.py",
    "content": "\"\"\"API adapters for different API formats (Claude, OpenAI Responses, etc.).\"\"\"\n"
  },
  {
    "path": "src/exo/api/adapters/chat_completions.py",
    "content": "\"\"\"OpenAI Chat Completions API adapter for converting requests/responses.\"\"\"\n\nimport time\nfrom collections.abc import AsyncGenerator\nfrom typing import Any\n\nfrom exo.api.types import (\n    ChatCompletionChoice,\n    ChatCompletionMessage,\n    ChatCompletionMessageText,\n    ChatCompletionRequest,\n    ChatCompletionResponse,\n    ErrorInfo,\n    ErrorResponse,\n    FinishReason,\n    Logprobs,\n    LogprobsContentItem,\n    StreamingChoiceResponse,\n    ToolCall,\n    Usage,\n)\nfrom exo.shared.types.chunks import (\n    ErrorChunk,\n    PrefillProgressChunk,\n    TokenChunk,\n    ToolCallChunk,\n)\nfrom exo.shared.types.common import CommandId\nfrom exo.shared.types.text_generation import (\n    InputMessage,\n    TextGenerationTaskParams,\n    resolve_reasoning_params,\n)\n\n\ndef chat_request_to_text_generation(\n    request: ChatCompletionRequest,\n) -> TextGenerationTaskParams:\n    instructions: str | None = None\n    input_messages: list[InputMessage] = []\n    chat_template_messages: list[dict[str, Any]] = []\n\n    for msg in request.messages:\n        # Normalize content to string\n        content: str\n        if msg.content is None:\n            content = \"\"\n        elif isinstance(msg.content, str):\n            content = msg.content\n        elif isinstance(msg.content, ChatCompletionMessageText):\n            content = msg.content.text\n        else:\n            # List of ChatCompletionMessageText\n            content = \"\\n\".join(item.text for item in msg.content)\n\n        # Extract system message as instructions\n        if msg.role == \"system\":\n            if instructions is None:\n                instructions = content\n            else:\n                # Append additional system messages\n                instructions = f\"{instructions}\\n{content}\"\n            chat_template_messages.append({\"role\": \"system\", \"content\": content})\n        else:\n            # Skip messages with no meaningful content\n            if (\n                msg.content is None\n                and msg.reasoning_content is None\n                and msg.tool_calls is None\n            ):\n                continue\n\n            if msg.role in (\"user\", \"assistant\", \"developer\"):\n                input_messages.append(InputMessage(role=msg.role, content=content))\n\n            # Build full message dict for chat template (preserves tool_calls etc.)\n            # Normalize content for model_dump\n            msg_copy = msg.model_copy(update={\"content\": content})\n            dumped: dict[str, Any] = msg_copy.model_dump(exclude_none=True)\n            chat_template_messages.append(dumped)\n\n    resolved_effort, resolved_thinking = resolve_reasoning_params(\n        request.reasoning_effort, request.enable_thinking\n    )\n\n    return TextGenerationTaskParams(\n        model=request.model,\n        input=input_messages\n        if input_messages\n        else [InputMessage(role=\"user\", content=\"\")],\n        instructions=instructions,\n        max_output_tokens=request.max_tokens,\n        temperature=request.temperature,\n        top_p=request.top_p,\n        top_k=request.top_k,\n        stop=request.stop,\n        seed=request.seed,\n        stream=request.stream,\n        tools=request.tools,\n        reasoning_effort=resolved_effort,\n        enable_thinking=resolved_thinking,\n        chat_template_messages=chat_template_messages\n        if chat_template_messages\n        else None,\n        logprobs=request.logprobs or False,\n  
      top_logprobs=request.top_logprobs,\n        min_p=request.min_p,\n        repetition_penalty=request.repetition_penalty,\n        repetition_context_size=request.repetition_context_size,\n    )\n\n\ndef chunk_to_response(\n    chunk: TokenChunk, command_id: CommandId\n) -> ChatCompletionResponse:\n    \"\"\"Convert a TokenChunk to a streaming ChatCompletionResponse.\"\"\"\n    # Build logprobs if available\n    logprobs: Logprobs | None = None\n    if chunk.logprob is not None:\n        logprobs = Logprobs(\n            content=[\n                LogprobsContentItem(\n                    token=chunk.text,\n                    logprob=chunk.logprob,\n                    top_logprobs=chunk.top_logprobs or [],\n                )\n            ]\n        )\n\n    if chunk.is_thinking:\n        delta = ChatCompletionMessage(role=\"assistant\", reasoning_content=chunk.text)\n    else:\n        delta = ChatCompletionMessage(role=\"assistant\", content=chunk.text)\n\n    return ChatCompletionResponse(\n        id=command_id,\n        created=int(time.time()),\n        model=chunk.model,\n        choices=[\n            StreamingChoiceResponse(\n                index=0,\n                delta=delta,\n                logprobs=logprobs,\n                finish_reason=chunk.finish_reason,\n            )\n        ],\n    )\n\n\nasync def generate_chat_stream(\n    command_id: CommandId,\n    chunk_stream: AsyncGenerator[\n        PrefillProgressChunk | ErrorChunk | ToolCallChunk | TokenChunk, None\n    ],\n) -> AsyncGenerator[str, None]:\n    \"\"\"Generate Chat Completions API streaming events from chunks.\"\"\"\n    last_usage: Usage | None = None\n\n    async for chunk in chunk_stream:\n        match chunk:\n            case PrefillProgressChunk():\n                # Use SSE comment so third-party clients ignore it\n                yield f\": prefill_progress {chunk.model_dump_json()}\\n\\n\"\n\n            case ErrorChunk():\n                error_response = ErrorResponse(\n                    error=ErrorInfo(\n                        message=chunk.error_message or \"Internal server error\",\n                        type=\"InternalServerError\",\n                        code=500,\n                    )\n                )\n                yield f\"data: {error_response.model_dump_json()}\\n\\n\"\n                yield \"data: [DONE]\\n\\n\"\n                return\n\n            case ToolCallChunk():\n                last_usage = chunk.usage or last_usage\n\n                tool_call_deltas = [\n                    ToolCall(\n                        id=tool.id,\n                        index=i,\n                        function=tool,\n                    )\n                    for i, tool in enumerate(chunk.tool_calls)\n                ]\n                tool_response = ChatCompletionResponse(\n                    id=command_id,\n                    created=int(time.time()),\n                    model=chunk.model,\n                    choices=[\n                        StreamingChoiceResponse(\n                            index=0,\n                            delta=ChatCompletionMessage(\n                                role=\"assistant\",\n                                tool_calls=tool_call_deltas,\n                            ),\n                            finish_reason=\"tool_calls\",\n                        )\n                    ],\n                    usage=last_usage,\n                )\n                yield f\"data: {tool_response.model_dump_json()}\\n\\n\"\n                yield 
\"data: [DONE]\\n\\n\"\n                return\n\n            case TokenChunk():\n                last_usage = chunk.usage or last_usage\n\n                chunk_response = chunk_to_response(chunk, command_id)\n                if chunk.finish_reason is not None:\n                    chunk_response = chunk_response.model_copy(\n                        update={\"usage\": last_usage}\n                    )\n                yield f\"data: {chunk_response.model_dump_json()}\\n\\n\"\n\n                if chunk.finish_reason is not None:\n                    yield \"data: [DONE]\\n\\n\"\n\n\nasync def collect_chat_response(\n    command_id: CommandId,\n    chunk_stream: AsyncGenerator[\n        ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None\n    ],\n) -> AsyncGenerator[str]:\n    # This is an AsyncGenerator[str] rather than returning a ChatCompletionReponse because\n    # FastAPI handles the cancellation better but wouldn't auto-serialize for some reason\n    \"\"\"Collect all token chunks and return a single ChatCompletionResponse.\"\"\"\n    text_parts: list[str] = []\n    thinking_parts: list[str] = []\n    tool_calls: list[ToolCall] = []\n    logprobs_content: list[LogprobsContentItem] = []\n    model: str | None = None\n    finish_reason: FinishReason | None = None\n    error_message: str | None = None\n    last_usage: Usage | None = None\n\n    async for chunk in chunk_stream:\n        match chunk:\n            case PrefillProgressChunk():\n                continue\n\n            case ErrorChunk():\n                error_message = chunk.error_message or \"Internal server error\"\n                break\n\n            case TokenChunk():\n                if model is None:\n                    model = chunk.model\n                last_usage = chunk.usage or last_usage\n                if chunk.is_thinking:\n                    thinking_parts.append(chunk.text)\n                else:\n                    text_parts.append(chunk.text)\n                if chunk.logprob is not None:\n                    logprobs_content.append(\n                        LogprobsContentItem(\n                            token=chunk.text,\n                            logprob=chunk.logprob,\n                            top_logprobs=chunk.top_logprobs or [],\n                        )\n                    )\n                if chunk.finish_reason is not None:\n                    finish_reason = chunk.finish_reason\n\n            case ToolCallChunk():\n                if model is None:\n                    model = chunk.model\n                last_usage = chunk.usage or last_usage\n                tool_calls.extend(\n                    ToolCall(\n                        id=tool.id,\n                        index=i,\n                        function=tool,\n                    )\n                    for i, tool in enumerate(chunk.tool_calls)\n                )\n                finish_reason = chunk.finish_reason\n\n    if error_message is not None:\n        raise ValueError(error_message)\n\n    combined_text = \"\".join(text_parts)\n    combined_thinking = \"\".join(thinking_parts) if thinking_parts else None\n    assert model is not None\n\n    yield ChatCompletionResponse(\n        id=command_id,\n        created=int(time.time()),\n        model=model,\n        choices=[\n            ChatCompletionChoice(\n                index=0,\n                message=ChatCompletionMessage(\n                    role=\"assistant\",\n                    content=combined_text,\n                    
reasoning_content=combined_thinking,\n                    tool_calls=tool_calls if tool_calls else None,\n                ),\n                logprobs=Logprobs(content=logprobs_content)\n                if logprobs_content\n                else None,\n                finish_reason=finish_reason,\n            )\n        ],\n        usage=last_usage,\n    ).model_dump_json()\n    return\n"
  },
  {
    "path": "src/exo/api/adapters/claude.py",
    "content": "\"\"\"Claude Messages API adapter for converting requests/responses.\"\"\"\n\nimport json\nimport re\nfrom collections.abc import AsyncGenerator\nfrom typing import Any\n\nfrom exo.api.types import FinishReason, Usage\nfrom exo.api.types.claude_api import (\n    ClaudeContentBlock,\n    ClaudeContentBlockDeltaEvent,\n    ClaudeContentBlockStartEvent,\n    ClaudeContentBlockStopEvent,\n    ClaudeInputJsonDelta,\n    ClaudeMessageDelta,\n    ClaudeMessageDeltaEvent,\n    ClaudeMessageDeltaUsage,\n    ClaudeMessagesRequest,\n    ClaudeMessagesResponse,\n    ClaudeMessageStart,\n    ClaudeMessageStartEvent,\n    ClaudeMessageStopEvent,\n    ClaudeStopReason,\n    ClaudeTextBlock,\n    ClaudeTextDelta,\n    ClaudeThinkingBlock,\n    ClaudeThinkingDelta,\n    ClaudeToolResultBlock,\n    ClaudeToolUseBlock,\n    ClaudeUsage,\n)\nfrom exo.shared.types.chunks import (\n    ErrorChunk,\n    PrefillProgressChunk,\n    TokenChunk,\n    ToolCallChunk,\n)\nfrom exo.shared.types.common import CommandId\nfrom exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams\n\n\ndef finish_reason_to_claude_stop_reason(\n    finish_reason: FinishReason | None,\n) -> ClaudeStopReason | None:\n    \"\"\"Map OpenAI finish_reason to Claude stop_reason.\"\"\"\n    if finish_reason is None:\n        return None\n    mapping: dict[FinishReason, ClaudeStopReason] = {\n        \"stop\": \"end_turn\",\n        \"length\": \"max_tokens\",\n        \"tool_calls\": \"tool_use\",\n        \"content_filter\": \"end_turn\",\n        \"function_call\": \"tool_use\",\n    }\n    return mapping.get(finish_reason, \"end_turn\")\n\n\ndef _extract_tool_result_text(block: ClaudeToolResultBlock) -> str:\n    \"\"\"Extract plain text from a tool_result content field.\"\"\"\n    if block.content is None:\n        return \"\"\n    if isinstance(block.content, str):\n        return block.content\n    return \"\".join(sub_block.text for sub_block in block.content)\n\n\n# Matches \"x-anthropic-billing-header: ...;\" (with optional trailing newline)\n# or similar telemetry headers that change every request and break KV prefix caching.\n_VOLATILE_HEADER_RE = re.compile(r\"^x-anthropic-[^\\n]*;\\n?\", re.MULTILINE)\n\n\ndef _strip_volatile_headers(text: str) -> str:\n    \"\"\"Remove Anthropic billing/telemetry headers from system prompt text.\n\n    Claude Code prepends headers like 'x-anthropic-billing-header: cc_version=...;\n    cc_entrypoint=...; cch=...;' that contain per-request content hashes. 
These\n    change every request and break KV prefix caching (the prefix diverges at ~20\n    tokens instead of matching thousands of conversation tokens).\n    \"\"\"\n    return _VOLATILE_HEADER_RE.sub(\"\", text)\n\n\ndef claude_request_to_text_generation(\n    request: ClaudeMessagesRequest,\n) -> TextGenerationTaskParams:\n    # Handle system message\n    instructions: str | None = None\n    chat_template_messages: list[dict[str, Any]] = []\n\n    if request.system:\n        if isinstance(request.system, str):\n            instructions = request.system\n        else:\n            instructions = \"\".join(block.text for block in request.system)\n\n        instructions = _strip_volatile_headers(instructions)\n        chat_template_messages.append({\"role\": \"system\", \"content\": instructions})\n\n    # Convert messages to input\n    input_messages: list[InputMessage] = []\n    for msg in request.messages:\n        if isinstance(msg.content, str):\n            input_messages.append(InputMessage(role=msg.role, content=msg.content))\n            chat_template_messages.append({\"role\": msg.role, \"content\": msg.content})\n            continue\n\n        # Process structured content blocks\n        text_parts: list[str] = []\n        thinking_parts: list[str] = []\n        tool_calls: list[dict[str, Any]] = []\n        tool_results: list[ClaudeToolResultBlock] = []\n\n        for block in msg.content:\n            if isinstance(block, ClaudeTextBlock):\n                text_parts.append(block.text)\n            elif isinstance(block, ClaudeThinkingBlock):\n                thinking_parts.append(block.thinking)\n            elif isinstance(block, ClaudeToolUseBlock):\n                tool_calls.append(\n                    {\n                        \"id\": block.id,\n                        \"type\": \"function\",\n                        \"function\": {\n                            \"name\": block.name,\n                            \"arguments\": json.dumps(block.input),\n                        },\n                    }\n                )\n            elif isinstance(block, ClaudeToolResultBlock):\n                tool_results.append(block)\n\n        content = \"\".join(text_parts)\n        reasoning_content = \"\".join(thinking_parts) if thinking_parts else None\n\n        # Build InputMessage from text content\n        if msg.role in (\"user\", \"assistant\"):\n            input_messages.append(InputMessage(role=msg.role, content=content))\n\n        # Build chat_template_messages preserving tool structure\n        if tool_calls:\n            chat_msg: dict[str, Any] = {\n                \"role\": \"assistant\",\n                \"content\": content,\n                \"tool_calls\": tool_calls,\n            }\n            if reasoning_content:\n                chat_msg[\"reasoning_content\"] = reasoning_content\n            chat_template_messages.append(chat_msg)\n        elif tool_results:\n            for tr in tool_results:\n                chat_template_messages.append(\n                    {\n                        \"role\": \"tool\",\n                        \"tool_call_id\": tr.tool_use_id,\n                        \"content\": _extract_tool_result_text(tr),\n                    }\n                )\n        else:\n            chat_msg = {\"role\": msg.role, \"content\": content}\n            if reasoning_content:\n                chat_msg[\"reasoning_content\"] = reasoning_content\n            chat_template_messages.append(chat_msg)\n\n    # Convert Claude tool definitions 
to OpenAI-style function tools\n    tools: list[dict[str, Any]] | None = None\n    if request.tools:\n        tools = [\n            {\n                \"type\": \"function\",\n                \"function\": {\n                    \"name\": tool.name,\n                    \"description\": tool.description or \"\",\n                    \"parameters\": tool.input_schema,\n                },\n            }\n            for tool in request.tools\n        ]\n\n    enable_thinking: bool | None = None\n    if request.thinking is not None:\n        enable_thinking = request.thinking.type in (\"enabled\", \"adaptive\")\n\n    return TextGenerationTaskParams(\n        model=request.model,\n        input=input_messages\n        if input_messages\n        else [InputMessage(role=\"user\", content=\"\")],\n        instructions=instructions,\n        max_output_tokens=request.max_tokens,\n        temperature=request.temperature,\n        top_p=request.top_p,\n        top_k=request.top_k,\n        stop=request.stop_sequences,\n        stream=request.stream,\n        tools=tools,\n        enable_thinking=enable_thinking,\n        chat_template_messages=chat_template_messages\n        if chat_template_messages\n        else None,\n    )\n\n\nasync def collect_claude_response(\n    command_id: CommandId,\n    model: str,\n    chunk_stream: AsyncGenerator[\n        ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None\n    ],\n) -> AsyncGenerator[str]:\n    # This is an AsyncGenerator[str] rather than returning a ClaudeMessagesResponse because\n    # FastAPI handles cancellation better but wouldn't auto-serialize for some reason\n    \"\"\"Collect all token chunks and return a single ClaudeMessagesResponse.\"\"\"\n    text_parts: list[str] = []\n    thinking_parts: list[str] = []\n    tool_use_blocks: list[ClaudeToolUseBlock] = []\n    stop_reason: ClaudeStopReason | None = None\n    last_usage: Usage | None = None\n    error_message: str | None = None\n\n    async for chunk in chunk_stream:\n        if isinstance(chunk, PrefillProgressChunk):\n            continue\n\n        if isinstance(chunk, ErrorChunk):\n            error_message = chunk.error_message or \"Internal server error\"\n            break\n\n        last_usage = chunk.usage or last_usage\n\n        if isinstance(chunk, ToolCallChunk):\n            for tool in chunk.tool_calls:\n                tool_use_blocks.append(\n                    ClaudeToolUseBlock(\n                        id=f\"toolu_{tool.id}\",\n                        name=tool.name,\n                        input=json.loads(tool.arguments),  # pyright: ignore[reportAny]\n                    )\n                )\n            stop_reason = \"tool_use\"\n            continue\n\n        if chunk.is_thinking:\n            thinking_parts.append(chunk.text)\n        else:\n            text_parts.append(chunk.text)\n\n        if chunk.finish_reason is not None:\n            stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)\n\n    if error_message is not None:\n        raise ValueError(error_message)\n\n    combined_text = \"\".join(text_parts)\n    combined_thinking = \"\".join(thinking_parts)\n\n    # Build content blocks\n    content: list[ClaudeContentBlock] = []\n    if combined_thinking:\n        content.append(ClaudeThinkingBlock(thinking=combined_thinking))\n    if combined_text:\n        content.append(ClaudeTextBlock(text=combined_text))\n    content.extend(tool_use_blocks)\n\n    # If no content at all, include empty text block\n    if not 
content:\n        content.append(ClaudeTextBlock(text=\"\"))\n\n    # Use actual usage data if available\n    input_tokens = last_usage.prompt_tokens if last_usage else 0\n    output_tokens = last_usage.completion_tokens if last_usage else 0\n\n    yield ClaudeMessagesResponse(\n        id=f\"msg_{command_id}\",\n        model=model,\n        content=content,\n        stop_reason=stop_reason,\n        usage=ClaudeUsage(\n            input_tokens=input_tokens,\n            output_tokens=output_tokens,\n        ),\n    ).model_dump_json()\n    return\n\n\nasync def generate_claude_stream(\n    command_id: CommandId,\n    model: str,\n    chunk_stream: AsyncGenerator[\n        ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None\n    ],\n) -> AsyncGenerator[str, None]:\n    \"\"\"Generate Claude Messages API streaming events from TokenChunks.\"\"\"\n    # Initial message_start event\n    initial_message = ClaudeMessageStart(\n        id=f\"msg_{command_id}\",\n        model=model,\n        content=[],\n        stop_reason=None,\n        usage=ClaudeUsage(input_tokens=0, output_tokens=0),\n    )\n    start_event = ClaudeMessageStartEvent(message=initial_message)\n    yield f\"event: message_start\\ndata: {start_event.model_dump_json()}\\n\\n\"\n\n    output_tokens = 0\n    stop_reason: ClaudeStopReason | None = None\n    last_usage: Usage | None = None\n    next_block_index = 0\n\n    # Track whether we've started thinking/text blocks\n    thinking_block_started = False\n    thinking_block_index = -1\n    text_block_started = False\n    text_block_index = -1\n\n    async for chunk in chunk_stream:\n        if isinstance(chunk, PrefillProgressChunk):\n            continue\n\n        if isinstance(chunk, ErrorChunk):\n            # Close text block and bail\n            break\n\n        last_usage = chunk.usage or last_usage\n\n        if isinstance(chunk, ToolCallChunk):\n            stop_reason = \"tool_use\"\n\n            # Emit tool_use content blocks\n            for tool in chunk.tool_calls:\n                tool_id = f\"toolu_{tool.id}\"\n                tool_input_json = tool.arguments\n\n                # content_block_start for tool_use\n                tool_block_start = ClaudeContentBlockStartEvent(\n                    index=next_block_index,\n                    content_block=ClaudeToolUseBlock(\n                        id=tool_id, name=tool.name, input={}\n                    ),\n                )\n                yield f\"event: content_block_start\\ndata: {tool_block_start.model_dump_json()}\\n\\n\"\n\n                # content_block_delta with input_json_delta\n                tool_delta_event = ClaudeContentBlockDeltaEvent(\n                    index=next_block_index,\n                    delta=ClaudeInputJsonDelta(partial_json=tool_input_json),\n                )\n                yield f\"event: content_block_delta\\ndata: {tool_delta_event.model_dump_json()}\\n\\n\"\n\n                # content_block_stop\n                tool_block_stop = ClaudeContentBlockStopEvent(index=next_block_index)\n                yield f\"event: content_block_stop\\ndata: {tool_block_stop.model_dump_json()}\\n\\n\"\n\n                next_block_index += 1\n            continue\n\n        output_tokens += 1  # Count each chunk as one token\n\n        if chunk.is_thinking:\n            # Start thinking block on first thinking token\n            if not thinking_block_started:\n                thinking_block_started = True\n                thinking_block_index = next_block_index\n     
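           # block indices are shared across thinking, text and tool_use blocks and only ever increase\n     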
           next_block_index += 1\n                block_start = ClaudeContentBlockStartEvent(\n                    index=thinking_block_index,\n                    content_block=ClaudeThinkingBlock(thinking=\"\"),\n                )\n                yield f\"event: content_block_start\\ndata: {block_start.model_dump_json()}\\n\\n\"\n\n            delta_event = ClaudeContentBlockDeltaEvent(\n                index=thinking_block_index,\n                delta=ClaudeThinkingDelta(thinking=chunk.text),\n            )\n            yield f\"event: content_block_delta\\ndata: {delta_event.model_dump_json()}\\n\\n\"\n        else:\n            # Close thinking block when transitioning to text\n            if thinking_block_started and text_block_index == -1:\n                block_stop = ClaudeContentBlockStopEvent(index=thinking_block_index)\n                yield f\"event: content_block_stop\\ndata: {block_stop.model_dump_json()}\\n\\n\"\n\n            # Start text block on first text token\n            if not text_block_started:\n                text_block_started = True\n                text_block_index = next_block_index\n                next_block_index += 1\n                block_start = ClaudeContentBlockStartEvent(\n                    index=text_block_index,\n                    content_block=ClaudeTextBlock(text=\"\"),\n                )\n                yield f\"event: content_block_start\\ndata: {block_start.model_dump_json()}\\n\\n\"\n\n            delta_event = ClaudeContentBlockDeltaEvent(\n                index=text_block_index,\n                delta=ClaudeTextDelta(text=chunk.text),\n            )\n            yield f\"event: content_block_delta\\ndata: {delta_event.model_dump_json()}\\n\\n\"\n\n        if chunk.finish_reason is not None:\n            stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)\n\n    # Use actual token count from usage if available\n    if last_usage is not None:\n        output_tokens = last_usage.completion_tokens\n\n    # Close any open blocks\n    if thinking_block_started and text_block_index == -1:\n        block_stop = ClaudeContentBlockStopEvent(index=thinking_block_index)\n        yield f\"event: content_block_stop\\ndata: {block_stop.model_dump_json()}\\n\\n\"\n\n    if text_block_started:\n        block_stop = ClaudeContentBlockStopEvent(index=text_block_index)\n        yield f\"event: content_block_stop\\ndata: {block_stop.model_dump_json()}\\n\\n\"\n\n    if not thinking_block_started and not text_block_started:\n        empty_start = ClaudeContentBlockStartEvent(\n            index=0, content_block=ClaudeTextBlock(text=\"\")\n        )\n        yield f\"event: content_block_start\\ndata: {empty_start.model_dump_json()}\\n\\n\"\n        empty_stop = ClaudeContentBlockStopEvent(index=0)\n        yield f\"event: content_block_stop\\ndata: {empty_stop.model_dump_json()}\\n\\n\"\n\n    # message_delta\n    message_delta = ClaudeMessageDeltaEvent(\n        delta=ClaudeMessageDelta(stop_reason=stop_reason),\n        usage=ClaudeMessageDeltaUsage(output_tokens=output_tokens),\n    )\n    yield f\"event: message_delta\\ndata: {message_delta.model_dump_json()}\\n\\n\"\n\n    # message_stop\n    message_stop = ClaudeMessageStopEvent()\n    yield f\"event: message_stop\\ndata: {message_stop.model_dump_json()}\\n\\n\"\n"
  },
  {
    "path": "src/exo/api/adapters/ollama.py",
    "content": "from __future__ import annotations\n\nimport json\nfrom collections.abc import AsyncGenerator\nfrom typing import Any\n\nfrom exo.api.types.ollama_api import (\n    OllamaChatRequest,\n    OllamaChatResponse,\n    OllamaDoneReason,\n    OllamaGenerateRequest,\n    OllamaGenerateResponse,\n    OllamaMessage,\n    OllamaToolCall,\n    OllamaToolFunction,\n)\nfrom exo.shared.types.chunks import (\n    ErrorChunk,\n    PrefillProgressChunk,\n    TokenChunk,\n    ToolCallChunk,\n)\nfrom exo.shared.types.common import CommandId\nfrom exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams\n\n\ndef _map_done_reason(\n    finish_reason: str | None,\n) -> OllamaDoneReason | None:\n    if finish_reason is None:\n        return None\n    if finish_reason == \"stop\":\n        return \"stop\"\n    if finish_reason == \"length\":\n        return \"length\"\n    if finish_reason in (\"tool_calls\", \"function_call\"):\n        return \"tool_call\"\n    if finish_reason == \"error\":\n        return \"error\"\n    return \"stop\"\n\n\ndef _try_parse_json(value: str) -> dict[str, Any] | str:\n    try:\n        return json.loads(value)  # type: ignore\n    except json.JSONDecodeError:\n        return value\n\n\ndef _build_tool_calls(chunk: ToolCallChunk) -> list[OllamaToolCall]:\n    tool_calls: list[OllamaToolCall] = []\n    for index, tool in enumerate(chunk.tool_calls):\n        # tool.arguments is always str; try to parse as JSON dict for Ollama format\n        arguments: dict[str, Any] | str = _try_parse_json(tool.arguments)\n        tool_calls.append(\n            OllamaToolCall(\n                id=tool.id,\n                type=\"function\",\n                function=OllamaToolFunction(\n                    name=tool.name, arguments=arguments, index=index\n                ),\n            )\n        )\n    return tool_calls\n\n\ndef _get_usage(\n    chunk: TokenChunk | ToolCallChunk,\n) -> tuple[int | None, int | None]:\n    \"\"\"Extract (prompt_eval_count, eval_count) from a chunk.\"\"\"\n    if chunk.usage is not None:\n        return (chunk.usage.prompt_tokens, chunk.usage.completion_tokens)\n    if chunk.stats is not None:\n        return (chunk.stats.prompt_tokens, chunk.stats.generation_tokens)\n    return (None, None)\n\n\ndef ollama_request_to_text_generation(\n    request: OllamaChatRequest,\n) -> TextGenerationTaskParams:\n    \"\"\"Convert Ollama chat request to exo's internal text generation format.\"\"\"\n    instructions: str | None = None\n    input_messages: list[InputMessage] = []\n    chat_template_messages: list[dict[str, Any]] = []\n    tool_message_index = 0\n\n    for msg in request.messages:\n        content = msg.content or \"\"\n\n        if msg.role == \"system\":\n            if instructions is None:\n                instructions = content\n            else:\n                instructions = f\"{instructions}\\n{content}\"\n            chat_template_messages.append({\"role\": \"system\", \"content\": content})\n            continue\n\n        if msg.role in (\"user\", \"assistant\") and (\n            msg.content is not None or msg.thinking is not None or msg.tool_calls\n        ):\n            input_messages.append(InputMessage(role=msg.role, content=content))\n\n        dumped: dict[str, Any] = {\"role\": msg.role, \"content\": content}\n        if msg.thinking is not None:\n            dumped[\"thinking\"] = msg.thinking\n        if msg.tool_calls is not None:\n            tool_calls_list: list[dict[str, Any]] = []\n            for 
tc in msg.tool_calls:\n                function: dict[str, Any] = {\n                    \"name\": tc.function.name,\n                    \"arguments\": (\n                        json.dumps(tc.function.arguments)\n                        if isinstance(tc.function.arguments, dict)\n                        else tc.function.arguments\n                    ),\n                }\n                if tc.function.index is not None:\n                    function[\"index\"] = tc.function.index\n                tool_call: dict[str, Any] = {\"function\": function}\n                if tc.id is not None:\n                    tool_call[\"id\"] = tc.id\n                if tc.type is not None:\n                    tool_call[\"type\"] = tc.type\n                tool_calls_list.append(tool_call)\n            dumped[\"tool_calls\"] = tool_calls_list\n        if msg.name is not None:\n            dumped[\"name\"] = msg.name\n        if msg.role == \"tool\":\n            tool_message_index += 1\n            tool_call_id = msg.tool_name or msg.name or f\"tool_{tool_message_index}\"\n            dumped[\"tool_call_id\"] = tool_call_id\n            if msg.tool_name is not None:\n                dumped[\"tool_name\"] = msg.tool_name\n        chat_template_messages.append(dumped)\n\n    options = request.options\n    return TextGenerationTaskParams(\n        model=request.model,\n        input=input_messages\n        if input_messages\n        else [InputMessage(role=\"user\", content=\"\")],\n        instructions=instructions,\n        max_output_tokens=options.num_predict if options else None,\n        temperature=options.temperature if options else None,\n        top_p=options.top_p if options else None,\n        top_k=options.top_k if options else None,\n        stop=options.stop if options else None,\n        seed=options.seed if options else None,\n        stream=request.stream,\n        tools=request.tools,\n        enable_thinking=request.think,\n        chat_template_messages=chat_template_messages\n        if chat_template_messages\n        else None,\n    )\n\n\nasync def generate_ollama_chat_stream(\n    _command_id: CommandId,\n    chunk_stream: AsyncGenerator[\n        ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None\n    ],\n) -> AsyncGenerator[str, None]:\n    \"\"\"Generate streaming responses in Ollama format (newline-delimited JSON).\"\"\"\n    thinking_parts: list[str] = []\n\n    async for chunk in chunk_stream:\n        match chunk:\n            case PrefillProgressChunk():\n                continue\n\n            case ErrorChunk():\n                error_response = OllamaChatResponse(\n                    model=str(chunk.model),\n                    message=OllamaMessage(\n                        role=\"assistant\", content=chunk.error_message\n                    ),\n                    done=True,\n                    done_reason=\"error\",\n                )\n                yield f\"{error_response.model_dump_json(exclude_none=True)}\\n\"\n                return\n\n            case ToolCallChunk():\n                prompt_eval, eval_count = _get_usage(chunk)\n                response = OllamaChatResponse(\n                    model=str(chunk.model),\n                    message=OllamaMessage(\n                        role=\"assistant\",\n                        content=\"\",\n                        tool_calls=_build_tool_calls(chunk),\n                        thinking=\"\".join(thinking_parts) if thinking_parts else None,\n                    ),\n                    
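# Ollama ends a tool-call turn immediately: one final NDJSON object with done_reason \"tool_call\"\n                    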
done=True,\n                    done_reason=\"tool_call\",\n                    prompt_eval_count=prompt_eval,\n                    eval_count=eval_count,\n                )\n                yield f\"{response.model_dump_json(exclude_none=True)}\\n\"\n                return\n\n            case TokenChunk():\n                done = chunk.finish_reason is not None\n\n                if chunk.is_thinking:\n                    thinking_parts.append(chunk.text)\n                    response = OllamaChatResponse(\n                        model=str(chunk.model),\n                        message=OllamaMessage(\n                            role=\"assistant\", content=\"\", thinking=chunk.text\n                        ),\n                        done=False,\n                    )\n                    yield f\"{response.model_dump_json(exclude_none=True)}\\n\"\n                elif done:\n                    prompt_eval, eval_count = _get_usage(chunk)\n                    response = OllamaChatResponse(\n                        model=str(chunk.model),\n                        message=OllamaMessage(\n                            role=\"assistant\",\n                            content=chunk.text,\n                        ),\n                        done=True,\n                        done_reason=_map_done_reason(chunk.finish_reason),\n                        prompt_eval_count=prompt_eval,\n                        eval_count=eval_count,\n                    )\n                    yield f\"{response.model_dump_json(exclude_none=True)}\\n\"\n                else:\n                    response = OllamaChatResponse(\n                        model=str(chunk.model),\n                        message=OllamaMessage(role=\"assistant\", content=chunk.text),\n                        done=False,\n                    )\n                    yield f\"{response.model_dump_json(exclude_none=True)}\\n\"\n\n                if done:\n                    return\n\n\nasync def collect_ollama_chat_response(\n    _command_id: CommandId,\n    chunk_stream: AsyncGenerator[\n        ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None\n    ],\n) -> AsyncGenerator[str]:\n    \"\"\"Collect streaming chunks into a single non-streaming Ollama response.\n\n    Returns an AsyncGenerator[str] (single yield) for consistency with FastAPI\n    StreamingResponse cancellation handling.\n    \"\"\"\n    text_parts: list[str] = []\n    thinking_parts: list[str] = []\n    tool_calls: list[OllamaToolCall] = []\n    model: str | None = None\n    finish_reason: str | None = None\n    prompt_eval_count: int | None = None\n    eval_count: int | None = None\n\n    async for chunk in chunk_stream:\n        match chunk:\n            case PrefillProgressChunk():\n                continue\n\n            case ErrorChunk():\n                raise ValueError(chunk.error_message or \"Internal server error\")\n\n            case TokenChunk():\n                if model is None:\n                    model = str(chunk.model)\n                if chunk.is_thinking:\n                    thinking_parts.append(chunk.text)\n                else:\n                    text_parts.append(chunk.text)\n                if chunk.finish_reason is not None:\n                    finish_reason = chunk.finish_reason\n                    prompt_eval_count, eval_count = _get_usage(chunk)\n\n            case ToolCallChunk():\n                if model is None:\n                    model = str(chunk.model)\n                tool_calls.extend(_build_tool_calls(chunk))\n     
           finish_reason = chunk.finish_reason\n                prompt_eval_count, eval_count = _get_usage(chunk)\n\n    combined_text = \"\".join(text_parts)\n    combined_thinking = \"\".join(thinking_parts) if thinking_parts else None\n    assert model is not None\n\n    yield OllamaChatResponse(\n        model=model,\n        message=OllamaMessage(\n            role=\"assistant\",\n            content=combined_text,\n            thinking=combined_thinking,\n            tool_calls=tool_calls if tool_calls else None,\n        ),\n        done=True,\n        done_reason=_map_done_reason(finish_reason),\n        prompt_eval_count=prompt_eval_count,\n        eval_count=eval_count,\n    ).model_dump_json(exclude_none=True)\n    return\n\n\n# ── /api/generate ──\n\n\ndef ollama_generate_request_to_text_generation(\n    request: OllamaGenerateRequest,\n) -> TextGenerationTaskParams:\n    \"\"\"Convert Ollama generate request to exo's internal text generation format.\"\"\"\n    chat_template_messages: list[dict[str, Any]] = []\n    if request.system:\n        chat_template_messages.append({\"role\": \"system\", \"content\": request.system})\n    chat_template_messages.append({\"role\": \"user\", \"content\": request.prompt})\n\n    options = request.options\n    return TextGenerationTaskParams(\n        model=request.model,\n        input=[InputMessage(role=\"user\", content=request.prompt)],\n        instructions=request.system,\n        max_output_tokens=options.num_predict if options else None,\n        temperature=options.temperature if options else None,\n        top_p=options.top_p if options else None,\n        top_k=options.top_k if options else None,\n        stop=options.stop if options else None,\n        seed=options.seed if options else None,\n        stream=request.stream,\n        enable_thinking=request.think,\n        chat_template_messages=chat_template_messages\n        if chat_template_messages\n        else None,\n    )\n\n\nasync def generate_ollama_generate_stream(\n    _command_id: CommandId,\n    chunk_stream: AsyncGenerator[\n        ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None\n    ],\n) -> AsyncGenerator[str, None]:\n    \"\"\"Generate streaming responses for /api/generate in Ollama NDJSON format.\"\"\"\n    thinking_parts: list[str] = []\n\n    async for chunk in chunk_stream:\n        match chunk:\n            case PrefillProgressChunk():\n                continue\n\n            case ErrorChunk():\n                resp = OllamaGenerateResponse(\n                    model=str(chunk.model),\n                    response=\"\",\n                    done=True,\n                    done_reason=\"error\",\n                )\n                yield f\"{resp.model_dump_json(exclude_none=True)}\\n\"\n                return\n\n            case ToolCallChunk():\n                # generate endpoint doesn't support tools; emit as done\n                prompt_eval, eval_count = _get_usage(chunk)\n                resp = OllamaGenerateResponse(\n                    model=str(chunk.model),\n                    response=\"\",\n                    done=True,\n                    done_reason=\"stop\",\n                    prompt_eval_count=prompt_eval,\n                    eval_count=eval_count,\n                )\n                yield f\"{resp.model_dump_json(exclude_none=True)}\\n\"\n                return\n\n            case TokenChunk():\n                done = chunk.finish_reason is not None\n\n                if chunk.is_thinking:\n                    
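# thinking tokens stream through the dedicated \"thinking\" field, with \"response\" left empty\n                    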
thinking_parts.append(chunk.text)\n                    resp = OllamaGenerateResponse(\n                        model=str(chunk.model),\n                        response=\"\",\n                        thinking=chunk.text,\n                        done=False,\n                    )\n                    yield f\"{resp.model_dump_json(exclude_none=True)}\\n\"\n                elif done:\n                    prompt_eval, eval_count = _get_usage(chunk)\n                    resp = OllamaGenerateResponse(\n                        model=str(chunk.model),\n                        response=chunk.text,\n                        done=True,\n                        done_reason=_map_done_reason(chunk.finish_reason),\n                        prompt_eval_count=prompt_eval,\n                        eval_count=eval_count,\n                    )\n                    yield f\"{resp.model_dump_json(exclude_none=True)}\\n\"\n                else:\n                    resp = OllamaGenerateResponse(\n                        model=str(chunk.model),\n                        response=chunk.text,\n                        done=False,\n                    )\n                    yield f\"{resp.model_dump_json(exclude_none=True)}\\n\"\n\n                if done:\n                    return\n\n\nasync def collect_ollama_generate_response(\n    _command_id: CommandId,\n    chunk_stream: AsyncGenerator[\n        ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None\n    ],\n) -> AsyncGenerator[str]:\n    \"\"\"Collect chunks into a single non-streaming /api/generate response.\"\"\"\n    text_parts: list[str] = []\n    thinking_parts: list[str] = []\n    model: str | None = None\n    finish_reason: str | None = None\n    prompt_eval_count: int | None = None\n    eval_count: int | None = None\n\n    async for chunk in chunk_stream:\n        match chunk:\n            case PrefillProgressChunk():\n                continue\n            case ErrorChunk():\n                raise ValueError(chunk.error_message or \"Internal server error\")\n            case TokenChunk():\n                if model is None:\n                    model = str(chunk.model)\n                if chunk.is_thinking:\n                    thinking_parts.append(chunk.text)\n                else:\n                    text_parts.append(chunk.text)\n                if chunk.finish_reason is not None:\n                    finish_reason = chunk.finish_reason\n                    prompt_eval_count, eval_count = _get_usage(chunk)\n            case ToolCallChunk():\n                if model is None:\n                    model = str(chunk.model)\n                finish_reason = chunk.finish_reason\n                prompt_eval_count, eval_count = _get_usage(chunk)\n\n    assert model is not None\n    yield OllamaGenerateResponse(\n        model=model,\n        response=\"\".join(text_parts),\n        thinking=\"\".join(thinking_parts) if thinking_parts else None,\n        done=True,\n        done_reason=_map_done_reason(finish_reason),\n        prompt_eval_count=prompt_eval_count,\n        eval_count=eval_count,\n    ).model_dump_json(exclude_none=True)\n    return\n"
  },
  {
    "path": "src/exo/api/adapters/responses.py",
    "content": "\"\"\"OpenAI Responses API adapter for converting requests/responses.\"\"\"\n\nfrom collections.abc import AsyncGenerator\nfrom itertools import count\nfrom typing import Any\n\nfrom exo.api.types import Usage\nfrom exo.api.types.openai_responses import (\n    FunctionCallInputItem,\n    ResponseCompletedEvent,\n    ResponseContentPart,\n    ResponseContentPartAddedEvent,\n    ResponseContentPartDoneEvent,\n    ResponseCreatedEvent,\n    ResponseFunctionCallArgumentsDeltaEvent,\n    ResponseFunctionCallArgumentsDoneEvent,\n    ResponseFunctionCallItem,\n    ResponseInProgressEvent,\n    ResponseInputMessage,\n    ResponseItem,\n    ResponseMessageItem,\n    ResponseOutputItemAddedEvent,\n    ResponseOutputItemDoneEvent,\n    ResponseOutputText,\n    ResponseReasoningItem,\n    ResponseReasoningSummaryPartAddedEvent,\n    ResponseReasoningSummaryPartDoneEvent,\n    ResponseReasoningSummaryText,\n    ResponseReasoningSummaryTextDeltaEvent,\n    ResponseReasoningSummaryTextDoneEvent,\n    ResponsesRequest,\n    ResponsesResponse,\n    ResponsesStreamEvent,\n    ResponseTextDeltaEvent,\n    ResponseTextDoneEvent,\n    ResponseUsage,\n)\nfrom exo.shared.types.chunks import (\n    ErrorChunk,\n    PrefillProgressChunk,\n    TokenChunk,\n    ToolCallChunk,\n)\nfrom exo.shared.types.common import CommandId\nfrom exo.shared.types.text_generation import (\n    InputMessage,\n    TextGenerationTaskParams,\n    resolve_reasoning_params,\n)\n\n\ndef _format_sse(event: ResponsesStreamEvent) -> str:\n    \"\"\"Format a streaming event as an SSE message.\"\"\"\n    return f\"event: {event.type}\\ndata: {event.model_dump_json()}\\n\\n\"\n\n\ndef _extract_content(content: str | list[ResponseContentPart]) -> str:\n    \"\"\"Extract plain text from a content field that may be a string or list of parts.\"\"\"\n    if isinstance(content, str):\n        return content\n    return \"\".join(part.text for part in content)\n\n\ndef responses_request_to_text_generation(\n    request: ResponsesRequest,\n) -> TextGenerationTaskParams:\n    input_value: list[InputMessage]\n    built_chat_template: list[dict[str, Any]] | None = None\n    if isinstance(request.input, str):\n        input_value = [InputMessage(role=\"user\", content=request.input)]\n    else:\n        input_messages: list[InputMessage] = []\n        chat_template_messages: list[dict[str, Any]] = []\n\n        if request.instructions is not None:\n            chat_template_messages.append(\n                {\"role\": \"system\", \"content\": request.instructions}\n            )\n\n        for item in request.input:\n            if isinstance(item, ResponseInputMessage):\n                content = _extract_content(item.content)\n                if item.role in (\"user\", \"assistant\", \"developer\"):\n                    input_messages.append(InputMessage(role=item.role, content=content))\n                if item.role == \"system\":\n                    chat_template_messages.append(\n                        {\"role\": \"system\", \"content\": content}\n                    )\n                else:\n                    chat_template_messages.append(\n                        {\"role\": item.role, \"content\": content}\n                    )\n            elif isinstance(item, FunctionCallInputItem):\n                chat_template_messages.append(\n                    {\n                        \"role\": \"assistant\",\n                        \"content\": \"\",\n                        \"tool_calls\": [\n                            {\n         
                       \"id\": item.call_id,\n                                \"type\": \"function\",\n                                \"function\": {\n                                    \"name\": item.name,\n                                    \"arguments\": item.arguments,\n                                },\n                            }\n                        ],\n                    }\n                )\n            else:\n                chat_template_messages.append(\n                    {\n                        \"role\": \"tool\",\n                        \"tool_call_id\": item.call_id,\n                        \"content\": item.output,\n                    }\n                )\n\n        input_value = (\n            input_messages\n            if input_messages\n            else [InputMessage(role=\"user\", content=\"\")]\n        )\n        built_chat_template = chat_template_messages if chat_template_messages else None\n\n    effort_from_reasoning = request.reasoning.effort if request.reasoning else None\n    resolved_effort, resolved_thinking = resolve_reasoning_params(\n        effort_from_reasoning, request.enable_thinking\n    )\n\n    # The responses API often does not provide tool args nested under a \"function\" field.\n    # Since we follow the chat completions format of tools in the backend (for MLX chat templates)\n    # we need to normalise to this format.\n    normalised_tools: list[dict[str, Any]] | None = None\n    if request.tools:\n        normalised_tools = []\n        for tool in request.tools:\n            if \"function\" in tool:\n                normalised_tools.append(tool)\n            else:\n                normalised_tools.append(\n                    {\n                        \"type\": \"function\",\n                        \"function\": {\n                            \"name\": tool.get(\"name\", \"\"),\n                            \"description\": tool.get(\"description\", \"\"),\n                            \"parameters\": tool.get(\"parameters\", {}),\n                            **({\"strict\": tool[\"strict\"]} if \"strict\" in tool else {}),\n                        },\n                    }\n                )\n\n    return TextGenerationTaskParams(\n        model=request.model,\n        input=input_value,\n        instructions=request.instructions,\n        max_output_tokens=request.max_output_tokens,\n        temperature=request.temperature,\n        top_p=request.top_p,\n        stream=request.stream,\n        tools=normalised_tools,\n        top_k=request.top_k,\n        stop=request.stop,\n        seed=request.seed,\n        chat_template_messages=built_chat_template or request.chat_template_messages,\n        reasoning_effort=resolved_effort,\n        enable_thinking=resolved_thinking,\n    )\n\n\nasync def collect_responses_response(\n    command_id: CommandId,\n    model: str,\n    chunk_stream: AsyncGenerator[\n        ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None\n    ],\n) -> AsyncGenerator[str]:\n    # This is an AsyncGenerator[str] rather than returning a ChatCompletionReponse because\n    # FastAPI handles the cancellation better but wouldn't auto-serialize for some reason\n    \"\"\"Collect all token chunks and return a single ResponsesResponse.\"\"\"\n    response_id = f\"resp_{command_id}\"\n    item_id = f\"item_{command_id}\"\n    reasoning_id = f\"rs_{command_id}\"\n    accumulated_text = \"\"\n    thinking_parts: list[str] = []\n    function_call_items: list[ResponseFunctionCallItem] = []\n    
last_usage: Usage | None = None\n    error_message: str | None = None\n\n    async for chunk in chunk_stream:\n        if isinstance(chunk, PrefillProgressChunk):\n            continue\n\n        if isinstance(chunk, ErrorChunk):\n            error_message = chunk.error_message or \"Internal server error\"\n            break\n\n        last_usage = chunk.usage or last_usage\n\n        if isinstance(chunk, ToolCallChunk):\n            for tool in chunk.tool_calls:\n                function_call_items.append(\n                    ResponseFunctionCallItem(\n                        id=tool.id,\n                        call_id=tool.id,\n                        name=tool.name,\n                        arguments=tool.arguments,\n                    )\n                )\n            continue\n\n        if chunk.is_thinking:\n            thinking_parts.append(chunk.text)\n            continue\n\n        accumulated_text += chunk.text\n\n    if error_message is not None:\n        raise ValueError(error_message)\n\n    # Create usage from usage data if available\n    usage = None\n    if last_usage is not None:\n        usage = ResponseUsage(\n            input_tokens=last_usage.prompt_tokens,\n            output_tokens=last_usage.completion_tokens,\n            total_tokens=last_usage.total_tokens,\n        )\n\n    output: list[ResponseItem] = []\n    if thinking_parts:\n        output.append(\n            ResponseReasoningItem(\n                id=reasoning_id,\n                summary=[ResponseReasoningSummaryText(text=\"\".join(thinking_parts))],\n            )\n        )\n    output.append(\n        ResponseMessageItem(\n            id=item_id,\n            content=[ResponseOutputText(text=accumulated_text)],\n            status=\"completed\",\n        )\n    )\n    output.extend(function_call_items)\n\n    yield ResponsesResponse(\n        id=response_id,\n        model=model,\n        status=\"completed\",\n        output=output,\n        output_text=accumulated_text,\n        usage=usage,\n    ).model_dump_json()\n    return\n\n\nasync def generate_responses_stream(\n    command_id: CommandId,\n    model: str,\n    chunk_stream: AsyncGenerator[\n        ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None\n    ],\n) -> AsyncGenerator[str, None]:\n    \"\"\"Generate OpenAI Responses API streaming events from TokenChunks.\"\"\"\n    response_id = f\"resp_{command_id}\"\n    item_id = f\"item_{command_id}\"\n    reasoning_id = f\"rs_{command_id}\"\n    seq = count(1)\n\n    # response.created\n    initial_response = ResponsesResponse(\n        id=response_id,\n        model=model,\n        status=\"in_progress\",\n        output=[],\n        output_text=\"\",\n    )\n    created_event = ResponseCreatedEvent(\n        sequence_number=next(seq), response=initial_response\n    )\n    yield _format_sse(created_event)\n\n    # response.in_progress\n    in_progress_event = ResponseInProgressEvent(\n        sequence_number=next(seq), response=initial_response\n    )\n    yield _format_sse(in_progress_event)\n\n    accumulated_text = \"\"\n    accumulated_thinking = \"\"\n    function_call_items: list[ResponseFunctionCallItem] = []\n    last_usage: Usage | None = None\n    next_output_index = 0\n\n    # Track dynamic block creation\n    reasoning_started = False\n    reasoning_output_index = -1\n    message_started = False\n    message_output_index = -1\n\n    async for chunk in chunk_stream:\n        if isinstance(chunk, PrefillProgressChunk):\n            continue\n\n        if 
isinstance(chunk, ErrorChunk):\n            break\n\n        last_usage = chunk.usage or last_usage\n\n        if isinstance(chunk, ToolCallChunk):\n            for tool in chunk.tool_calls:\n                fc_id = f\"fc_{tool.id}\"\n                call_id = f\"call_{tool.id}\"\n\n                # response.output_item.added for function_call\n                fc_item = ResponseFunctionCallItem(\n                    id=fc_id,\n                    call_id=call_id,\n                    name=tool.name,\n                    arguments=\"\",\n                    status=\"in_progress\",\n                )\n                fc_added = ResponseOutputItemAddedEvent(\n                    sequence_number=next(seq),\n                    output_index=next_output_index,\n                    item=fc_item,\n                )\n                yield _format_sse(fc_added)\n\n                # response.function_call_arguments.delta\n                args_delta = ResponseFunctionCallArgumentsDeltaEvent(\n                    sequence_number=next(seq),\n                    item_id=fc_id,\n                    output_index=next_output_index,\n                    delta=tool.arguments,\n                )\n                yield _format_sse(args_delta)\n\n                # response.function_call_arguments.done\n                args_done = ResponseFunctionCallArgumentsDoneEvent(\n                    sequence_number=next(seq),\n                    item_id=fc_id,\n                    output_index=next_output_index,\n                    name=tool.name,\n                    arguments=tool.arguments,\n                )\n                yield _format_sse(args_done)\n\n                # response.output_item.done\n                fc_done_item = ResponseFunctionCallItem(\n                    id=fc_id,\n                    call_id=call_id,\n                    name=tool.name,\n                    arguments=tool.arguments,\n                    status=\"completed\",\n                )\n                fc_item_done = ResponseOutputItemDoneEvent(\n                    sequence_number=next(seq),\n                    output_index=next_output_index,\n                    item=fc_done_item,\n                )\n                yield _format_sse(fc_item_done)\n\n                function_call_items.append(fc_done_item)\n                next_output_index += 1\n            continue\n\n        if chunk.is_thinking:\n            # Start reasoning block on first thinking token\n            if not reasoning_started:\n                reasoning_started = True\n                reasoning_output_index = next_output_index\n                next_output_index += 1\n\n                # response.output_item.added for reasoning\n                reasoning_item = ResponseReasoningItem(\n                    id=reasoning_id,\n                    summary=[],\n                    status=\"in_progress\",\n                )\n                rs_added = ResponseOutputItemAddedEvent(\n                    sequence_number=next(seq),\n                    output_index=reasoning_output_index,\n                    item=reasoning_item,\n                )\n                yield _format_sse(rs_added)\n\n                # response.reasoning_summary_part.added\n                part_added = ResponseReasoningSummaryPartAddedEvent(\n                    sequence_number=next(seq),\n                    item_id=reasoning_id,\n                    output_index=reasoning_output_index,\n                    summary_index=0,\n                    part=ResponseReasoningSummaryText(text=\"\"),\n    
            )\n                yield _format_sse(part_added)\n\n            accumulated_thinking += chunk.text\n\n            # response.reasoning_summary_text.delta\n            rs_delta = ResponseReasoningSummaryTextDeltaEvent(\n                sequence_number=next(seq),\n                item_id=reasoning_id,\n                output_index=reasoning_output_index,\n                summary_index=0,\n                delta=chunk.text,\n            )\n            yield _format_sse(rs_delta)\n            continue\n\n        # Close reasoning block when transitioning to text\n        if reasoning_started and not message_started:\n            # response.reasoning_summary_text.done\n            rs_text_done = ResponseReasoningSummaryTextDoneEvent(\n                sequence_number=next(seq),\n                item_id=reasoning_id,\n                output_index=reasoning_output_index,\n                summary_index=0,\n                text=accumulated_thinking,\n            )\n            yield _format_sse(rs_text_done)\n\n            # response.reasoning_summary_part.done\n            rs_part_done = ResponseReasoningSummaryPartDoneEvent(\n                sequence_number=next(seq),\n                item_id=reasoning_id,\n                output_index=reasoning_output_index,\n                summary_index=0,\n                part=ResponseReasoningSummaryText(text=accumulated_thinking),\n            )\n            yield _format_sse(rs_part_done)\n\n            # response.output_item.done for reasoning\n            rs_item_done = ResponseOutputItemDoneEvent(\n                sequence_number=next(seq),\n                output_index=reasoning_output_index,\n                item=ResponseReasoningItem(\n                    id=reasoning_id,\n                    summary=[ResponseReasoningSummaryText(text=accumulated_thinking)],\n                ),\n            )\n            yield _format_sse(rs_item_done)\n\n        # Start message block on first text token\n        if not message_started:\n            message_started = True\n            message_output_index = next_output_index\n            next_output_index += 1\n\n            initial_item = ResponseMessageItem(\n                id=item_id,\n                content=[ResponseOutputText(text=\"\")],\n                status=\"in_progress\",\n            )\n            item_added = ResponseOutputItemAddedEvent(\n                sequence_number=next(seq),\n                output_index=message_output_index,\n                item=initial_item,\n            )\n            yield _format_sse(item_added)\n\n            initial_part = ResponseOutputText(text=\"\")\n            part_added = ResponseContentPartAddedEvent(\n                sequence_number=next(seq),\n                item_id=item_id,\n                output_index=message_output_index,\n                content_index=0,\n                part=initial_part,\n            )\n            yield _format_sse(part_added)\n\n        accumulated_text += chunk.text\n\n        # response.output_text.delta\n        delta_event = ResponseTextDeltaEvent(\n            sequence_number=next(seq),\n            item_id=item_id,\n            output_index=message_output_index,\n            content_index=0,\n            delta=chunk.text,\n        )\n        yield _format_sse(delta_event)\n\n    # Close reasoning block if it was never followed by text\n    if reasoning_started and not message_started:\n        rs_text_done = ResponseReasoningSummaryTextDoneEvent(\n            sequence_number=next(seq),\n            
item_id=reasoning_id,\n            output_index=reasoning_output_index,\n            summary_index=0,\n            text=accumulated_thinking,\n        )\n        yield _format_sse(rs_text_done)\n\n        rs_part_done = ResponseReasoningSummaryPartDoneEvent(\n            sequence_number=next(seq),\n            item_id=reasoning_id,\n            output_index=reasoning_output_index,\n            summary_index=0,\n            part=ResponseReasoningSummaryText(text=accumulated_thinking),\n        )\n        yield _format_sse(rs_part_done)\n\n        rs_item_done = ResponseOutputItemDoneEvent(\n            sequence_number=next(seq),\n            output_index=reasoning_output_index,\n            item=ResponseReasoningItem(\n                id=reasoning_id,\n                summary=[ResponseReasoningSummaryText(text=accumulated_thinking)],\n            ),\n        )\n        yield _format_sse(rs_item_done)\n\n    # If no message block was started, create one now (empty text)\n    if not message_started:\n        message_output_index = next_output_index\n        next_output_index += 1\n\n        initial_item = ResponseMessageItem(\n            id=item_id,\n            content=[ResponseOutputText(text=\"\")],\n            status=\"in_progress\",\n        )\n        item_added = ResponseOutputItemAddedEvent(\n            sequence_number=next(seq),\n            output_index=message_output_index,\n            item=initial_item,\n        )\n        yield _format_sse(item_added)\n\n        initial_part = ResponseOutputText(text=\"\")\n        part_added_evt = ResponseContentPartAddedEvent(\n            sequence_number=next(seq),\n            item_id=item_id,\n            output_index=message_output_index,\n            content_index=0,\n            part=initial_part,\n        )\n        yield _format_sse(part_added_evt)\n\n    # response.output_text.done\n    text_done = ResponseTextDoneEvent(\n        sequence_number=next(seq),\n        item_id=item_id,\n        output_index=message_output_index,\n        content_index=0,\n        text=accumulated_text,\n    )\n    yield _format_sse(text_done)\n\n    # response.content_part.done\n    final_part = ResponseOutputText(text=accumulated_text)\n    part_done = ResponseContentPartDoneEvent(\n        sequence_number=next(seq),\n        item_id=item_id,\n        output_index=message_output_index,\n        content_index=0,\n        part=final_part,\n    )\n    yield _format_sse(part_done)\n\n    # response.output_item.done\n    final_message_item = ResponseMessageItem(\n        id=item_id,\n        content=[ResponseOutputText(text=accumulated_text)],\n        status=\"completed\",\n    )\n    item_done = ResponseOutputItemDoneEvent(\n        sequence_number=next(seq),\n        output_index=message_output_index,\n        item=final_message_item,\n    )\n    yield _format_sse(item_done)\n\n    # Create usage from usage data if available\n    usage = None\n    if last_usage is not None:\n        usage = ResponseUsage(\n            input_tokens=last_usage.prompt_tokens,\n            output_tokens=last_usage.completion_tokens,\n            total_tokens=last_usage.total_tokens,\n        )\n\n    # response.completed\n    output: list[ResponseItem] = []\n    if reasoning_started:\n        output.append(\n            ResponseReasoningItem(\n                id=reasoning_id,\n                summary=[ResponseReasoningSummaryText(text=accumulated_thinking)],\n            )\n        )\n    output.append(final_message_item)\n    output.extend(function_call_items)\n    
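# Output ordering: reasoning first (if any), then the message, then tool calls.\n    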
final_response = ResponsesResponse(\n        id=response_id,\n        model=model,\n        status=\"completed\",\n        output=output,\n        output_text=accumulated_text,\n        usage=usage,\n    )\n    completed_event = ResponseCompletedEvent(\n        sequence_number=next(seq), response=final_response\n    )\n    yield _format_sse(completed_event)\n"
  },
  {
    "path": "src/exo/api/main.py",
    "content": "import base64\nimport contextlib\nimport json\nimport random\nimport time\nfrom collections.abc import AsyncGenerator, Awaitable, Callable, Iterable\nfrom datetime import datetime, timezone\nfrom http import HTTPStatus\nfrom pathlib import Path\nfrom typing import Annotated, Literal, cast\nfrom uuid import uuid4\n\nimport anyio\nfrom anyio import BrokenResourceError\nfrom fastapi import FastAPI, File, Form, HTTPException, Query, Request, UploadFile\nfrom fastapi.middleware.cors import CORSMiddleware\nfrom fastapi.responses import FileResponse, JSONResponse, StreamingResponse\nfrom fastapi.staticfiles import StaticFiles\nfrom hypercorn.asyncio import serve  # pyright: ignore[reportUnknownVariableType]\nfrom hypercorn.config import Config\nfrom hypercorn.typing import ASGIFramework\nfrom loguru import logger\n\nfrom exo.api.adapters.chat_completions import (\n    chat_request_to_text_generation,\n    collect_chat_response,\n    generate_chat_stream,\n)\nfrom exo.api.adapters.claude import (\n    claude_request_to_text_generation,\n    collect_claude_response,\n    generate_claude_stream,\n)\nfrom exo.api.adapters.ollama import (\n    collect_ollama_chat_response,\n    collect_ollama_generate_response,\n    generate_ollama_chat_stream,\n    generate_ollama_generate_stream,\n    ollama_generate_request_to_text_generation,\n    ollama_request_to_text_generation,\n)\nfrom exo.api.adapters.responses import (\n    collect_responses_response,\n    generate_responses_stream,\n    responses_request_to_text_generation,\n)\nfrom exo.api.types import (\n    AddCustomModelParams,\n    AdvancedImageParams,\n    BenchChatCompletionRequest,\n    BenchChatCompletionResponse,\n    BenchImageGenerationResponse,\n    BenchImageGenerationTaskParams,\n    CancelCommandResponse,\n    ChatCompletionChoice,\n    ChatCompletionMessage,\n    ChatCompletionRequest,\n    ChatCompletionResponse,\n    CreateInstanceParams,\n    CreateInstanceResponse,\n    DeleteDownloadResponse,\n    DeleteInstanceResponse,\n    DeleteTracesRequest,\n    DeleteTracesResponse,\n    ErrorInfo,\n    ErrorResponse,\n    FinishReason,\n    GenerationStats,\n    HuggingFaceSearchResult,\n    ImageData,\n    ImageEditsTaskParams,\n    ImageGenerationResponse,\n    ImageGenerationStats,\n    ImageGenerationTaskParams,\n    ImageListItem,\n    ImageListResponse,\n    ImageSize,\n    ModelList,\n    ModelListModel,\n    PlaceInstanceParams,\n    PlacementPreview,\n    PlacementPreviewResponse,\n    StartDownloadParams,\n    StartDownloadResponse,\n    ToolCall,\n    TraceCategoryStats,\n    TraceEventResponse,\n    TraceListItem,\n    TraceListResponse,\n    TraceRankStats,\n    TraceResponse,\n    TraceStatsResponse,\n    normalize_image_size,\n)\nfrom exo.api.types.claude_api import (\n    ClaudeMessagesRequest,\n    ClaudeMessagesResponse,\n)\nfrom exo.api.types.ollama_api import (\n    OllamaChatRequest,\n    OllamaChatResponse,\n    OllamaGenerateRequest,\n    OllamaGenerateResponse,\n    OllamaModelDetails,\n    OllamaModelTag,\n    OllamaPsModel,\n    OllamaPsResponse,\n    OllamaShowRequest,\n    OllamaShowResponse,\n    OllamaTagsResponse,\n)\nfrom exo.api.types.openai_responses import (\n    ResponsesRequest,\n    ResponsesResponse,\n)\nfrom exo.master.image_store import ImageStore\nfrom exo.master.placement import place_instance as get_instance_placements\nfrom exo.shared.apply import apply\nfrom exo.shared.constants import (\n    DASHBOARD_DIR,\n    EXO_CACHE_HOME,\n    EXO_EVENT_LOG_DIR,\n    EXO_IMAGE_CACHE_DIR,\n    
EXO_MAX_CHUNK_SIZE,\n    EXO_TRACING_CACHE_DIR,\n)\nfrom exo.shared.election import ElectionMessage\nfrom exo.shared.logging import InterceptLogger\nfrom exo.shared.models.model_cards import (\n    ModelCard,\n    ModelId,\n    delete_custom_card,\n    get_model_cards,\n    is_custom_card,\n)\nfrom exo.shared.tracing import TraceEvent, compute_stats, export_trace, load_trace_file\nfrom exo.shared.types.chunks import (\n    ErrorChunk,\n    ImageChunk,\n    InputImageChunk,\n    PrefillProgressChunk,\n    TokenChunk,\n    ToolCallChunk,\n)\nfrom exo.shared.types.commands import (\n    Command,\n    CreateInstance,\n    DeleteDownload,\n    DeleteInstance,\n    DownloadCommand,\n    ForwarderCommand,\n    ForwarderDownloadCommand,\n    ImageEdits,\n    ImageGeneration,\n    PlaceInstance,\n    SendInputChunk,\n    StartDownload,\n    TaskCancelled,\n    TaskFinished,\n    TextGeneration,\n)\nfrom exo.shared.types.common import CommandId, Id, NodeId, SystemId\nfrom exo.shared.types.events import (\n    ChunkGenerated,\n    Event,\n    IndexedEvent,\n    TracesMerged,\n)\nfrom exo.shared.types.memory import Memory\nfrom exo.shared.types.state import State\nfrom exo.shared.types.worker.downloads import DownloadCompleted\nfrom exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta\nfrom exo.shared.types.worker.shards import Sharding\nfrom exo.utils.banner import print_startup_banner\nfrom exo.utils.channels import Receiver, Sender, channel\nfrom exo.utils.disk_event_log import DiskEventLog\nfrom exo.utils.power_sampler import PowerSampler\nfrom exo.utils.task_group import TaskGroup\n\n_API_EVENT_LOG_DIR = EXO_EVENT_LOG_DIR / \"api\"\nONBOARDING_COMPLETE_FILE = EXO_CACHE_HOME / \"onboarding_complete\"\n\n\ndef _format_to_content_type(image_format: Literal[\"png\", \"jpeg\", \"webp\"] | None) -> str:\n    return f\"image/{image_format or 'png'}\"\n\n\ndef _ensure_seed(params: AdvancedImageParams | None) -> AdvancedImageParams:\n    \"\"\"Ensure advanced params has a seed set for distributed consistency.\"\"\"\n    if params is None:\n        return AdvancedImageParams(seed=random.randint(0, 2**32 - 1))\n    if params.seed is None:\n        return params.model_copy(update={\"seed\": random.randint(0, 2**32 - 1)})\n    return params\n\n\nclass API:\n    def __init__(\n        self,\n        node_id: NodeId,\n        *,\n        port: int,\n        event_receiver: Receiver[IndexedEvent],\n        command_sender: Sender[ForwarderCommand],\n        download_command_sender: Sender[ForwarderDownloadCommand],\n        # This lets us pause the API if an election is running\n        election_receiver: Receiver[ElectionMessage],\n    ) -> None:\n        self.state = State()\n        self._event_log = DiskEventLog(_API_EVENT_LOG_DIR)\n        self._system_id = SystemId()\n        self.command_sender = command_sender\n        self.download_command_sender = download_command_sender\n        self.event_receiver = event_receiver\n        self.election_receiver = election_receiver\n        self.node_id: NodeId = node_id\n        self.last_completed_election: int = 0\n        self.port = port\n\n        self.paused: bool = False\n        self.paused_ev: anyio.Event = anyio.Event()\n\n        self.app = FastAPI()\n\n        @self.app.middleware(\"http\")\n        async def _log_requests(  # pyright: ignore[reportUnusedFunction]\n            request: Request,\n            call_next: Callable[[Request], Awaitable[StreamingResponse]],\n        ) -> StreamingResponse:\n            logger.debug(f\"API 
request: {request.method} {request.url.path}\")\n            return await call_next(request)\n\n        self._setup_exception_handlers()\n        self._setup_cors()\n        self._setup_routes()\n\n        self.app.mount(\n            \"/\",\n            StaticFiles(\n                directory=DASHBOARD_DIR,\n                html=True,\n            ),\n            name=\"dashboard\",\n        )\n\n        self._text_generation_queues: dict[\n            CommandId,\n            Sender[TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk],\n        ] = {}\n        self._image_generation_queues: dict[\n            CommandId, Sender[ImageChunk | ErrorChunk]\n        ] = {}\n        self._image_store = ImageStore(EXO_IMAGE_CACHE_DIR)\n        self._tg: TaskGroup = TaskGroup()\n\n    def reset(self, result_clock: int, event_receiver: Receiver[IndexedEvent]):\n        logger.info(\"Resetting API State\")\n        self._event_log.close()\n        self._event_log = DiskEventLog(_API_EVENT_LOG_DIR)\n        self.state = State()\n        self._system_id = SystemId()\n        self._text_generation_queues = {}\n        self._image_generation_queues = {}\n        self.unpause(result_clock)\n        self.event_receiver.close()\n        self.event_receiver = event_receiver\n        self._tg.start_soon(self._apply_state)\n\n    def unpause(self, result_clock: int):\n        logger.info(\"Unpausing API\")\n        self.last_completed_election = result_clock\n        self.paused = False\n        self.paused_ev.set()\n        self.paused_ev = anyio.Event()\n\n    def _setup_exception_handlers(self) -> None:\n        self.app.exception_handler(HTTPException)(self.http_exception_handler)\n\n    async def http_exception_handler(\n        self, _: Request, exc: HTTPException\n    ) -> JSONResponse:\n        err = ErrorResponse(\n            error=ErrorInfo(\n                message=exc.detail,\n                type=HTTPStatus(exc.status_code).phrase,\n                code=exc.status_code,\n            )\n        )\n        return JSONResponse(err.model_dump(), status_code=exc.status_code)\n\n    def _setup_cors(self) -> None:\n        self.app.add_middleware(\n            CORSMiddleware,\n            allow_origins=[\"*\"],\n            allow_credentials=True,\n            allow_methods=[\"*\"],\n            allow_headers=[\"*\"],\n        )\n\n    def _setup_routes(self) -> None:\n        self.app.get(\"/node_id\")(lambda: self.node_id)\n        self.app.post(\"/instance\")(self.create_instance)\n        self.app.post(\"/place_instance\")(self.place_instance)\n        self.app.get(\"/instance/placement\")(self.get_placement)\n        self.app.get(\"/instance/previews\")(self.get_placement_previews)\n        self.app.get(\"/instance/{instance_id}\")(self.get_instance)\n        self.app.delete(\"/instance/{instance_id}\")(self.delete_instance)\n        self.app.get(\"/models\")(self.get_models)\n        self.app.get(\"/v1/models\")(self.get_models)\n        self.app.post(\"/models/add\")(self.add_custom_model)\n        self.app.delete(\"/models/custom/{model_id:path}\")(self.delete_custom_model)\n        self.app.get(\"/models/search\")(self.search_models)\n        self.app.post(\"/v1/chat/completions\", response_model=None)(\n            self.chat_completions\n        )\n        self.app.post(\"/bench/chat/completions\")(self.bench_chat_completions)\n        self.app.post(\"/v1/images/generations\", response_model=None)(\n            self.image_generations\n        )\n        
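# The /bench variants run the same pipeline but additionally report generation statistics.\n        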
self.app.post(\"/bench/images/generations\")(self.bench_image_generations)\n        self.app.post(\"/v1/images/edits\", response_model=None)(self.image_edits)\n        self.app.post(\"/bench/images/edits\")(self.bench_image_edits)\n        self.app.get(\"/images\")(self.list_images)\n        self.app.get(\"/images/{image_id}\")(self.get_image)\n        self.app.post(\"/v1/messages\", response_model=None)(self.claude_messages)\n        self.app.post(\"/v1/responses\", response_model=None)(self.openai_responses)\n        self.app.post(\"/v1/cancel/{command_id}\")(self.cancel_command)\n\n        # Ollama API\n        self.app.head(\"/ollama/\")(self.ollama_version)\n        self.app.head(\"/ollama/api/version\")(self.ollama_version)\n        self.app.post(\"/ollama/api/chat\", response_model=None)(self.ollama_chat)\n        self.app.post(\"/ollama/api/api/chat\", response_model=None)(self.ollama_chat)\n        self.app.post(\"/ollama/api/v1/chat\", response_model=None)(self.ollama_chat)\n        self.app.post(\"/ollama/api/generate\", response_model=None)(self.ollama_generate)\n        self.app.get(\"/ollama/api/tags\")(self.ollama_tags)\n        self.app.get(\"/ollama/api/api/tags\")(self.ollama_tags)\n        self.app.get(\"/ollama/api/v1/tags\")(self.ollama_tags)\n        self.app.post(\"/ollama/api/show\")(self.ollama_show)\n        self.app.get(\"/ollama/api/ps\")(self.ollama_ps)\n        self.app.get(\"/ollama/api/version\")(self.ollama_version)\n\n        self.app.get(\"/state\")(lambda: self.state)\n        self.app.get(\"/events\")(self.stream_events)\n        self.app.post(\"/download/start\")(self.start_download)\n        self.app.delete(\"/download/{node_id}/{model_id:path}\")(self.delete_download)\n        self.app.get(\"/v1/traces\")(self.list_traces)\n        self.app.post(\"/v1/traces/delete\")(self.delete_traces)\n        self.app.get(\"/v1/traces/{task_id}\")(self.get_trace)\n        self.app.get(\"/v1/traces/{task_id}/stats\")(self.get_trace_stats)\n        self.app.get(\"/v1/traces/{task_id}/raw\")(self.get_trace_raw)\n        self.app.get(\"/onboarding\")(self.get_onboarding)\n        self.app.post(\"/onboarding\")(self.complete_onboarding)\n\n    async def place_instance(self, payload: PlaceInstanceParams):\n        command = PlaceInstance(\n            model_card=await ModelCard.load(payload.model_id),\n            sharding=payload.sharding,\n            instance_meta=payload.instance_meta,\n            min_nodes=payload.min_nodes,\n        )\n        await self._send(command)\n\n        return CreateInstanceResponse(\n            message=\"Command received.\",\n            command_id=command.command_id,\n            model_card=command.model_card,\n        )\n\n    async def create_instance(\n        self, payload: CreateInstanceParams\n    ) -> CreateInstanceResponse:\n        instance = payload.instance\n        model_card = await ModelCard.load(instance.shard_assignments.model_id)\n        required_memory = model_card.storage_size\n        available_memory = self._calculate_total_available_memory()\n\n        if required_memory > available_memory:\n            raise HTTPException(\n                status_code=400,\n                detail=f\"Insufficient memory to create instance. 
Required: {required_memory.in_gb:.1f}GB, Available: {available_memory.in_gb:.1f}GB\",\n            )\n\n        command = CreateInstance(\n            instance=instance,\n        )\n        await self._send(command)\n\n        return CreateInstanceResponse(\n            message=\"Command received.\",\n            command_id=command.command_id,\n            model_card=model_card,\n        )\n\n    async def get_placement(\n        self,\n        model_id: ModelId,\n        sharding: Sharding = Sharding.Pipeline,\n        instance_meta: InstanceMeta = InstanceMeta.MlxRing,\n        min_nodes: int = 1,\n    ) -> Instance:\n        model_card = await ModelCard.load(model_id)\n\n        try:\n            placements = get_instance_placements(\n                PlaceInstance(\n                    model_card=model_card,\n                    sharding=sharding,\n                    instance_meta=instance_meta,\n                    min_nodes=min_nodes,\n                ),\n                node_memory=self.state.node_memory,\n                node_network=self.state.node_network,\n                topology=self.state.topology,\n                current_instances=self.state.instances,\n            )\n        except ValueError as exc:\n            raise HTTPException(status_code=400, detail=str(exc)) from exc\n\n        current_ids = set(self.state.instances.keys())\n        new_ids = [\n            instance_id for instance_id in placements if instance_id not in current_ids\n        ]\n        if len(new_ids) != 1:\n            raise HTTPException(\n                status_code=500,\n                detail=\"Expected exactly one new instance from placement\",\n            )\n\n        return placements[new_ids[0]]\n\n    async def get_placement_previews(\n        self,\n        model_id: ModelId,\n        node_ids: Annotated[list[NodeId] | None, Query()] = None,\n    ) -> PlacementPreviewResponse:\n        seen: set[tuple[ModelId, Sharding, InstanceMeta, int]] = set()\n        previews: list[PlacementPreview] = []\n        required_nodes = set(node_ids) if node_ids else None\n\n        if len(list(self.state.topology.list_nodes())) == 0:\n            return PlacementPreviewResponse(previews=[])\n\n        try:\n            model_card = await ModelCard.load(model_id)\n        except Exception as exc:\n            raise HTTPException(\n                status_code=400, detail=f\"Failed to load model card: {exc}\"\n            ) from exc\n        instance_combinations: list[tuple[Sharding, InstanceMeta, int]] = []\n        for sharding in (Sharding.Pipeline, Sharding.Tensor):\n            for instance_meta in (InstanceMeta.MlxRing, InstanceMeta.MlxJaccl):\n                instance_combinations.extend(\n                    [\n                        (sharding, instance_meta, i)\n                        for i in range(\n                            1, len(list(self.state.topology.list_nodes())) + 1\n                        )\n                    ]\n                )\n        # TODO: PDD\n        # instance_combinations.append((Sharding.PrefillDecodeDisaggregation, InstanceMeta.MlxRing, 1))\n\n        for sharding, instance_meta, min_nodes in instance_combinations:\n            try:\n                placements = get_instance_placements(\n                    PlaceInstance(\n                        model_card=model_card,\n                        sharding=sharding,\n                        instance_meta=instance_meta,\n                        min_nodes=min_nodes,\n                    ),\n                    
node_memory=self.state.node_memory,\n                    node_network=self.state.node_network,\n                    topology=self.state.topology,\n                    current_instances=self.state.instances,\n                    required_nodes=required_nodes,\n                )\n            except ValueError as exc:\n                if (model_card.model_id, sharding, instance_meta, 0) not in seen:\n                    previews.append(\n                        PlacementPreview(\n                            model_id=model_card.model_id,\n                            sharding=sharding,\n                            instance_meta=instance_meta,\n                            instance=None,\n                            error=str(exc),\n                        )\n                    )\n                seen.add((model_card.model_id, sharding, instance_meta, 0))\n                continue\n\n            current_ids = set(self.state.instances.keys())\n            new_instances = [\n                instance\n                for instance_id, instance in placements.items()\n                if instance_id not in current_ids\n            ]\n\n            if len(new_instances) != 1:\n                if (model_card.model_id, sharding, instance_meta, 0) not in seen:\n                    previews.append(\n                        PlacementPreview(\n                            model_id=model_card.model_id,\n                            sharding=sharding,\n                            instance_meta=instance_meta,\n                            instance=None,\n                            error=\"Expected exactly one new instance from placement\",\n                        )\n                    )\n                seen.add((model_card.model_id, sharding, instance_meta, 0))\n                continue\n\n            instance = new_instances[0]\n            shard_assignments = instance.shard_assignments\n            placement_node_ids = list(shard_assignments.node_to_runner.keys())\n\n            memory_delta_by_node: dict[str, int] = {}\n            if placement_node_ids:\n                total_bytes = model_card.storage_size.in_bytes\n                per_node = total_bytes // len(placement_node_ids)\n                remainder = total_bytes % len(placement_node_ids)\n                for index, node_id in enumerate(sorted(placement_node_ids, key=str)):\n                    extra = 1 if index < remainder else 0\n                    memory_delta_by_node[str(node_id)] = per_node + extra\n\n            if (\n                model_card.model_id,\n                sharding,\n                instance_meta,\n                len(placement_node_ids),\n            ) not in seen:\n                previews.append(\n                    PlacementPreview(\n                        model_id=model_card.model_id,\n                        sharding=sharding,\n                        instance_meta=instance_meta,\n                        instance=instance,\n                        memory_delta_by_node=memory_delta_by_node or None,\n                        error=None,\n                    )\n                )\n            seen.add(\n                (\n                    model_card.model_id,\n                    sharding,\n                    instance_meta,\n                    len(placement_node_ids),\n                )\n            )\n\n        return PlacementPreviewResponse(previews=previews)\n\n    def get_instance(self, instance_id: InstanceId) -> Instance:\n        if instance_id not in self.state.instances:\n            raise 
HTTPException(status_code=404, detail=\"Instance not found\")\n        return self.state.instances[instance_id]\n\n    async def delete_instance(self, instance_id: InstanceId) -> DeleteInstanceResponse:\n        if instance_id not in self.state.instances:\n            raise HTTPException(status_code=404, detail=\"Instance not found\")\n\n        command = DeleteInstance(\n            instance_id=instance_id,\n        )\n        await self._send(command)\n        return DeleteInstanceResponse(\n            message=\"Command received.\",\n            command_id=command.command_id,\n            instance_id=instance_id,\n        )\n\n    async def cancel_command(self, command_id: CommandId) -> CancelCommandResponse:\n        \"\"\"Cancel an active command by closing its stream and notifying workers.\"\"\"\n        sender = self._text_generation_queues.get(\n            command_id\n        ) or self._image_generation_queues.get(command_id)\n        if sender is None:\n            raise HTTPException(\n                status_code=404,\n                detail=\"Command not found or already completed\",\n            )\n\n        await self._send(TaskCancelled(cancelled_command_id=command_id))\n        sender.close()\n\n        return CancelCommandResponse(\n            message=\"Command cancelled.\",\n            command_id=command_id,\n        )\n\n    async def _token_chunk_stream(\n        self, command_id: CommandId\n    ) -> AsyncGenerator[\n        TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk, None\n    ]:\n        \"\"\"Yield chunks for a given command until completion.\n\n        This is the internal low-level stream used by all API adapters.\n        \"\"\"\n        try:\n            self._text_generation_queues[command_id], recv = channel[\n                TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk\n            ]()\n\n            with recv as token_chunks:\n                async for chunk in token_chunks:\n                    yield chunk\n                    if isinstance(chunk, PrefillProgressChunk):\n                        continue\n                    if chunk.finish_reason is not None:\n                        break\n\n        except anyio.get_cancelled_exc_class():\n            command = TaskCancelled(cancelled_command_id=command_id)\n            with anyio.CancelScope(shield=True):\n                await self.command_sender.send(\n                    ForwarderCommand(origin=self._system_id, command=command)\n                )\n            raise\n        finally:\n            await self._send(TaskFinished(finished_command_id=command_id))\n            if command_id in self._text_generation_queues:\n                del self._text_generation_queues[command_id]\n\n    async def _collect_text_generation_with_stats(\n        self, command_id: CommandId\n    ) -> BenchChatCompletionResponse:\n        sampler = PowerSampler(get_node_system=lambda: self.state.node_system)\n        text_parts: list[str] = []\n        tool_calls: list[ToolCall] = []\n        model: ModelId | None = None\n        finish_reason: FinishReason | None = None\n\n        stats: GenerationStats | None = None\n\n        async with anyio.create_task_group() as tg:\n            tg.start_soon(sampler.run)\n\n            async for chunk in self._token_chunk_stream(command_id):\n                if isinstance(chunk, PrefillProgressChunk):\n                    continue\n\n                if chunk.finish_reason == \"error\":\n                    raise HTTPException(\n                      
  status_code=500,\n                        detail=chunk.error_message or \"Internal server error\",\n                    )\n\n                if model is None:\n                    model = chunk.model\n\n                if isinstance(chunk, TokenChunk):\n                    text_parts.append(chunk.text)\n\n                if isinstance(chunk, ToolCallChunk):\n                    tool_calls.extend(\n                        ToolCall(\n                            id=str(uuid4()),\n                            index=i,\n                            function=tool,\n                        )\n                        for i, tool in enumerate(chunk.tool_calls)\n                    )\n\n                stats = chunk.stats or stats\n\n                if chunk.finish_reason is not None:\n                    finish_reason = chunk.finish_reason\n\n            tg.cancel_scope.cancel()\n\n        combined_text = \"\".join(text_parts)\n        assert model is not None\n\n        return BenchChatCompletionResponse(\n            id=command_id,\n            created=int(time.time()),\n            model=model,\n            choices=[\n                ChatCompletionChoice(\n                    index=0,\n                    message=ChatCompletionMessage(\n                        role=\"assistant\",\n                        content=combined_text,\n                        tool_calls=tool_calls if tool_calls else None,\n                    ),\n                    finish_reason=finish_reason,\n                )\n            ],\n            generation_stats=stats,\n            power_usage=sampler.result(),\n        )\n\n    async def _trigger_notify_user_to_download_model(self, model_id: ModelId) -> None:\n        logger.warning(\n            \"TODO: we should send a notification to the user to download the model\"\n        )\n\n    async def chat_completions(\n        self, payload: ChatCompletionRequest\n    ) -> ChatCompletionResponse | StreamingResponse:\n        \"\"\"OpenAI Chat Completions API - adapter.\"\"\"\n        task_params = chat_request_to_text_generation(payload)\n        resolved_model = await self._resolve_and_validate_text_model(\n            ModelId(task_params.model)\n        )\n        task_params = task_params.model_copy(update={\"model\": resolved_model})\n\n        command = TextGeneration(task_params=task_params)\n        await self._send(command)\n\n        if payload.stream:\n            return StreamingResponse(\n                generate_chat_stream(\n                    command.command_id,\n                    self._token_chunk_stream(command.command_id),\n                ),\n                media_type=\"text/event-stream\",\n                headers={\n                    \"Cache-Control\": \"no-cache\",\n                    \"Connection\": \"close\",\n                    \"X-Accel-Buffering\": \"no\",\n                },\n            )\n        else:\n            return StreamingResponse(\n                collect_chat_response(\n                    command.command_id,\n                    self._token_chunk_stream(command.command_id),\n                ),\n                media_type=\"application/json\",\n            )\n\n    async def bench_chat_completions(\n        self, payload: BenchChatCompletionRequest\n    ) -> BenchChatCompletionResponse:\n        task_params = chat_request_to_text_generation(payload)\n        resolved_model = await self._resolve_and_validate_text_model(\n            ModelId(task_params.model)\n        )\n        task_params = 
task_params.model_copy(update={\"model\": resolved_model})\n\n        task_params = task_params.model_copy(update={\"stream\": False, \"bench\": True})\n\n        command = TextGeneration(task_params=task_params)\n        await self._send(command)\n\n        return await self._collect_text_generation_with_stats(command.command_id)\n\n    async def _resolve_and_validate_text_model(self, model_id: ModelId) -> ModelId:\n        \"\"\"Validate a text model exists and return the resolved model ID.\n\n        Raises HTTPException 404 if no instance is found for the model.\n        \"\"\"\n        if not any(\n            instance.shard_assignments.model_id == model_id\n            for instance in self.state.instances.values()\n        ):\n            await self._trigger_notify_user_to_download_model(model_id)\n            raise HTTPException(\n                status_code=404,\n                detail=f\"No instance found for model {model_id}\",\n            )\n        return model_id\n\n    async def _validate_image_model(self, model: ModelId) -> ModelId:\n        \"\"\"Validate model exists and return resolved model ID.\n\n        Raises HTTPException 404 if no instance is found for the model.\n        \"\"\"\n        model_card = await ModelCard.load(model)\n        resolved_model = model_card.model_id\n        if not any(\n            instance.shard_assignments.model_id == resolved_model\n            for instance in self.state.instances.values()\n        ):\n            await self._trigger_notify_user_to_download_model(resolved_model)\n            raise HTTPException(\n                status_code=404, detail=f\"No instance found for model {resolved_model}\"\n            )\n        return resolved_model\n\n    def stream_events(self) -> StreamingResponse:\n        def _generate_json_array(events: Iterable[Event]) -> Iterable[str]:\n            yield \"[\"\n            first = True\n            for event in events:\n                if not first:\n                    yield \",\"\n                first = False\n                yield event.model_dump_json()\n            yield \"]\"\n\n        return StreamingResponse(\n            _generate_json_array(self._event_log.read_all()),\n            media_type=\"application/json\",\n        )\n\n    async def get_image(self, image_id: str) -> FileResponse:\n        stored = self._image_store.get(Id(image_id))\n        if stored is None:\n            raise HTTPException(status_code=404, detail=\"Image not found or expired\")\n        return FileResponse(path=stored.file_path, media_type=stored.content_type)\n\n    async def list_images(self, request: Request) -> ImageListResponse:\n        \"\"\"List all stored images.\"\"\"\n        stored_images = self._image_store.list_images()\n        return ImageListResponse(\n            data=[\n                ImageListItem(\n                    image_id=img.image_id,\n                    url=self._build_image_url(request, img.image_id),\n                    content_type=img.content_type,\n                    expires_at=img.expires_at,\n                )\n                for img in stored_images\n            ]\n        )\n\n    def _build_image_url(self, request: Request, image_id: Id) -> str:\n        host = request.headers.get(\"host\", f\"localhost:{self.port}\")\n        scheme = \"https\" if request.url.scheme == \"https\" else \"http\"\n        return f\"{scheme}://{host}/v1/images/{image_id}\"\n\n    async def image_generations(\n        self, request: Request, payload: ImageGenerationTaskParams\n    ) -> 
ImageGenerationResponse | StreamingResponse:\n        \"\"\"Handle image generation requests.\n\n        When stream=True and partial_images > 0, returns a StreamingResponse\n        with SSE-formatted events for partial and final images.\n        \"\"\"\n        payload = payload.model_copy(\n            update={\n                \"model\": await self._validate_image_model(ModelId(payload.model)),\n                \"advanced_params\": _ensure_seed(payload.advanced_params),\n            }\n        )\n\n        command = ImageGeneration(\n            task_params=payload,\n        )\n        await self._send(command)\n\n        # Check if streaming is requested\n        if payload.stream and payload.partial_images and payload.partial_images > 0:\n            return StreamingResponse(\n                self._generate_image_stream(\n                    request=request,\n                    command_id=command.command_id,\n                    num_images=payload.n or 1,\n                    response_format=payload.response_format or \"b64_json\",\n                ),\n                media_type=\"text/event-stream\",\n            )\n\n        # Non-streaming: collect all image chunks\n        return await self._collect_image_generation(\n            request=request,\n            command_id=command.command_id,\n            num_images=payload.n or 1,\n            response_format=payload.response_format or \"b64_json\",\n        )\n\n    async def _generate_image_stream(\n        self,\n        request: Request,\n        command_id: CommandId,\n        num_images: int,\n        response_format: str,\n    ) -> AsyncGenerator[str, None]:\n        \"\"\"Generate SSE stream of partial and final images.\"\"\"\n        # Track chunks: {(image_index, is_partial): {chunk_index: data}}\n        image_chunks: dict[tuple[int, bool], dict[int, str]] = {}\n        image_total_chunks: dict[tuple[int, bool], int] = {}\n        image_metadata: dict[tuple[int, bool], tuple[int | None, int | None]] = {}\n        images_complete = 0\n\n        try:\n            self._image_generation_queues[command_id], recv = channel[\n                ImageChunk | ErrorChunk\n            ]()\n\n            with recv as chunks:\n                async for chunk in chunks:\n                    if chunk.finish_reason == \"error\":\n                        error_response = ErrorResponse(\n                            error=ErrorInfo(\n                                message=chunk.error_message or \"Internal server error\",\n                                type=\"InternalServerError\",\n                                code=500,\n                            )\n                        )\n                        yield f\"data: {error_response.model_dump_json()}\\n\\n\"\n                        yield \"data: [DONE]\\n\\n\"\n                        return\n\n                    key = (chunk.image_index, chunk.is_partial)\n\n                    if key not in image_chunks:\n                        image_chunks[key] = {}\n                        image_total_chunks[key] = chunk.total_chunks\n                        image_metadata[key] = (\n                            chunk.partial_index,\n                            chunk.total_partials,\n                        )\n\n                    image_chunks[key][chunk.chunk_index] = chunk.data\n\n                    # Check if this image is complete\n                    if len(image_chunks[key]) == image_total_chunks[key]:\n                        full_data = \"\".join(\n                            
image_chunks[key][i] for i in range(len(image_chunks[key]))\n                        )\n\n                        partial_idx, total_partials = image_metadata[key]\n\n                        if chunk.is_partial:\n                            # Yield partial image event; partial data is populated only when b64_json output was requested\n                            event_data = {\n                                \"type\": \"partial\",\n                                \"image_index\": chunk.image_index,\n                                \"partial_index\": partial_idx,\n                                \"total_partials\": total_partials,\n                                \"format\": str(chunk.format),\n                                \"data\": {\n                                    \"b64_json\": full_data\n                                    if response_format == \"b64_json\"\n                                    else None,\n                                },\n                            }\n                            yield f\"data: {json.dumps(event_data)}\\n\\n\"\n                        else:\n                            # Final image\n                            if response_format == \"url\":\n                                image_bytes = base64.b64decode(full_data)\n                                content_type = _format_to_content_type(chunk.format)\n                                stored = self._image_store.store(\n                                    image_bytes, content_type\n                                )\n                                url = self._build_image_url(request, stored.image_id)\n                                event_data = {\n                                    \"type\": \"final\",\n                                    \"image_index\": chunk.image_index,\n                                    \"format\": str(chunk.format),\n                                    \"data\": {\"url\": url},\n                                }\n                            else:\n                                event_data = {\n                                    \"type\": \"final\",\n                                    \"image_index\": chunk.image_index,\n                                    \"format\": str(chunk.format),\n                                    \"data\": {\"b64_json\": full_data},\n                                }\n                            yield f\"data: {json.dumps(event_data)}\\n\\n\"\n                            images_complete += 1\n\n                            if images_complete >= num_images:\n                                yield \"data: [DONE]\\n\\n\"\n                                break\n\n                        # Clean up completed image chunks\n                        del image_chunks[key]\n                        del image_total_chunks[key]\n                        del image_metadata[key]\n\n        except anyio.get_cancelled_exc_class():\n            command = TaskCancelled(cancelled_command_id=command_id)\n            with anyio.CancelScope(shield=True):\n                await self.command_sender.send(\n                    ForwarderCommand(origin=self._system_id, command=command)\n                )\n            raise\n        finally:\n            await self._send(TaskFinished(finished_command_id=command_id))\n            if command_id in self._image_generation_queues:\n                del self._image_generation_queues[command_id]\n\n    async def _collect_image_chunks(\n        self,\n        request: Request | None,\n        command_id: CommandId,\n        num_images: int,\n        response_format: 
str,\n        capture_stats: bool = False,\n    ) -> tuple[list[ImageData], ImageGenerationStats | None]:\n        \"\"\"Collect image chunks and optionally capture stats.\"\"\"\n        # Track chunks per image: {image_index: {chunk_index: data}}\n        # Only track non-partial (final) images\n        image_chunks: dict[int, dict[int, str]] = {}\n        image_total_chunks: dict[int, int] = {}\n        image_formats: dict[int, Literal[\"png\", \"jpeg\", \"webp\"] | None] = {}\n        images_complete = 0\n        stats: ImageGenerationStats | None = None\n\n        try:\n            self._image_generation_queues[command_id], recv = channel[\n                ImageChunk | ErrorChunk\n            ]()\n\n            # Single pass over the receive channel: the break below ends\n            # collection once all requested images are complete, and the\n            # channel is never re-entered after it closes.\n            with recv as chunks:\n                async for chunk in chunks:\n                    if chunk.finish_reason == \"error\":\n                        raise HTTPException(\n                            status_code=500,\n                            detail=chunk.error_message or \"Internal server error\",\n                        )\n\n                    if chunk.is_partial:\n                        continue\n\n                    if chunk.image_index not in image_chunks:\n                        image_chunks[chunk.image_index] = {}\n                        image_total_chunks[chunk.image_index] = chunk.total_chunks\n                        image_formats[chunk.image_index] = chunk.format\n\n                    image_chunks[chunk.image_index][chunk.chunk_index] = chunk.data\n\n                    if capture_stats and chunk.stats is not None:\n                        stats = chunk.stats\n\n                    if (\n                        len(image_chunks[chunk.image_index])\n                        == image_total_chunks[chunk.image_index]\n                    ):\n                        images_complete += 1\n\n                    if images_complete >= num_images:\n                        break\n\n            images: list[ImageData] = []\n            for image_idx in range(num_images):\n                chunks_dict = image_chunks[image_idx]\n                full_data = \"\".join(chunks_dict[i] for i in range(len(chunks_dict)))\n                if response_format == \"url\" and request is not None:\n                    image_bytes = base64.b64decode(full_data)\n                    content_type = _format_to_content_type(image_formats.get(image_idx))\n                    stored = self._image_store.store(image_bytes, content_type)\n                    url = self._build_image_url(request, stored.image_id)\n                    images.append(ImageData(b64_json=None, url=url))\n                else:\n                    images.append(\n                        ImageData(\n                            b64_json=full_data\n                            if response_format == \"b64_json\"\n                            else None,\n                            url=None,\n                        )\n                    )\n\n            return (images, stats if capture_stats else None)\n        except anyio.get_cancelled_exc_class():\n            command = TaskCancelled(cancelled_command_id=command_id)\n            with anyio.CancelScope(shield=True):\n                await self.command_sender.send(\n                    ForwarderCommand(origin=self._system_id, command=command)\n                )\n            raise\n        finally:\n            # Always emit TaskFinished and release the per-command queue, even on cancellation.\n            await 
self._send(TaskFinished(finished_command_id=command_id))\n            if command_id in self._image_generation_queues:\n                del self._image_generation_queues[command_id]\n\n    async def _collect_image_generation(\n        self,\n        request: Request,\n        command_id: CommandId,\n        num_images: int,\n        response_format: str,\n    ) -> ImageGenerationResponse:\n        \"\"\"Collect all image chunks (non-streaming) and return a single response.\"\"\"\n        images, _ = await self._collect_image_chunks(\n            request, command_id, num_images, response_format, capture_stats=False\n        )\n        return ImageGenerationResponse(data=images)\n\n    async def _collect_image_generation_with_stats(\n        self,\n        request: Request | None,\n        command_id: CommandId,\n        num_images: int,\n        response_format: str,\n    ) -> BenchImageGenerationResponse:\n        sampler = PowerSampler(get_node_system=lambda: self.state.node_system)\n        images: list[ImageData] = []\n        stats: ImageGenerationStats | None = None\n        async with anyio.create_task_group() as tg:\n            tg.start_soon(sampler.run)\n            images, stats = await self._collect_image_chunks(\n                request, command_id, num_images, response_format, capture_stats=True\n            )\n            tg.cancel_scope.cancel()\n        return BenchImageGenerationResponse(\n            data=images, generation_stats=stats, power_usage=sampler.result()\n        )\n\n    async def bench_image_generations(\n        self, request: Request, payload: BenchImageGenerationTaskParams\n    ) -> BenchImageGenerationResponse:\n        payload = payload.model_copy(\n            update={\n                \"model\": await self._validate_image_model(ModelId(payload.model)),\n                \"stream\": False,\n                \"partial_images\": 0,\n                \"advanced_params\": _ensure_seed(payload.advanced_params),\n            }\n        )\n\n        command = ImageGeneration(\n            task_params=payload,\n        )\n        await self._send(command)\n\n        return await self._collect_image_generation_with_stats(\n            request=request,\n            command_id=command.command_id,\n            num_images=payload.n or 1,\n            response_format=payload.response_format or \"b64_json\",\n        )\n\n    async def _send_image_edits_command(\n        self,\n        image: UploadFile,\n        prompt: str,\n        model: ModelId,\n        n: int,\n        size: ImageSize,\n        response_format: Literal[\"url\", \"b64_json\"],\n        input_fidelity: Literal[\"low\", \"high\"],\n        stream: bool,\n        partial_images: int,\n        bench: bool,\n        quality: Literal[\"high\", \"medium\", \"low\"],\n        output_format: Literal[\"png\", \"jpeg\", \"webp\"],\n        advanced_params: AdvancedImageParams | None,\n    ) -> ImageEdits:\n        \"\"\"Prepare and send an image edits command with chunked image upload.\"\"\"\n        resolved_model = await self._validate_image_model(model)\n        advanced_params = _ensure_seed(advanced_params)\n\n        image_content = await image.read()\n        image_data = base64.b64encode(image_content).decode(\"utf-8\")\n\n        image_strength = 0.7 if input_fidelity == \"high\" else 0.3\n\n        data_chunks = [\n            image_data[i : i + EXO_MAX_CHUNK_SIZE]\n            for i in range(0, len(image_data), EXO_MAX_CHUNK_SIZE)\n        ]\n        total_chunks = len(data_chunks)\n\n        
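# The command carries an empty image_data placeholder; the actual base64\n        # payload is streamed separately below as InputImageChunk messages.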
\n        command = ImageEdits(\n            task_params=ImageEditsTaskParams(\n                image_data=\"\",\n                total_input_chunks=total_chunks,\n                prompt=prompt,\n                model=resolved_model,\n                n=n,\n                size=size,\n                response_format=response_format,\n                image_strength=image_strength,\n                stream=stream,\n                partial_images=partial_images,\n                bench=bench,\n                quality=quality,\n                output_format=output_format,\n                advanced_params=advanced_params,\n            ),\n        )\n\n        logger.info(\n            f\"Sending input image: {len(image_data)} base64 characters in {total_chunks} chunks\"\n        )\n        for chunk_index, chunk_data in enumerate(data_chunks):\n            await self._send(\n                SendInputChunk(\n                    chunk=InputImageChunk(\n                        model=resolved_model,\n                        command_id=command.command_id,\n                        data=chunk_data,\n                        chunk_index=chunk_index,\n                        total_chunks=total_chunks,\n                    )\n                )\n            )\n\n        await self._send(command)\n        return command\n\n    async def image_edits(\n        self,\n        request: Request,\n        image: UploadFile = File(...),  # noqa: B008\n        prompt: str = Form(...),\n        model: str = Form(...),\n        n: int = Form(1),\n        size: str | None = Form(None),\n        response_format: Literal[\"url\", \"b64_json\"] = Form(\"b64_json\"),\n        input_fidelity: Literal[\"low\", \"high\"] = Form(\"low\"),\n        stream: str = Form(\"false\"),\n        partial_images: str = Form(\"0\"),\n        quality: Literal[\"high\", \"medium\", \"low\"] = Form(\"medium\"),\n        output_format: Literal[\"png\", \"jpeg\", \"webp\"] = Form(\"png\"),\n        advanced_params: str | None = Form(None),\n    ) -> ImageGenerationResponse | StreamingResponse:\n        \"\"\"Handle image editing requests (img2img).\"\"\"\n        # Parse string form values to proper types\n        stream_bool = stream.lower() in (\"true\", \"1\", \"yes\")\n        partial_images_int = int(partial_images) if partial_images.isdigit() else 0\n\n        parsed_advanced_params: AdvancedImageParams | None = None\n        if advanced_params:\n            # Malformed advanced_params JSON is silently ignored rather than rejected.\n            with contextlib.suppress(Exception):\n                parsed_advanced_params = AdvancedImageParams.model_validate_json(\n                    advanced_params\n                )\n\n        command = await self._send_image_edits_command(\n            image=image,\n            prompt=prompt,\n            model=ModelId(model),\n            n=n,\n            size=normalize_image_size(size),\n            response_format=response_format,\n            input_fidelity=input_fidelity,\n            stream=stream_bool,\n            partial_images=partial_images_int,\n            bench=False,\n            quality=quality,\n            output_format=output_format,\n            advanced_params=parsed_advanced_params,\n        )\n\n        if stream_bool and partial_images_int > 0:\n            return StreamingResponse(\n                self._generate_image_stream(\n                    request=request,\n                    command_id=command.command_id,\n                    num_images=n,\n                    response_format=response_format,\n                ),\n                media_type=\"text/event-stream\",\n            
)\n\n        return await self._collect_image_generation(\n            request=request,\n            command_id=command.command_id,\n            num_images=n,\n            response_format=response_format,\n        )\n\n    async def bench_image_edits(\n        self,\n        request: Request,\n        image: UploadFile = File(...),  # noqa: B008\n        prompt: str = Form(...),\n        model: str = Form(...),\n        n: int = Form(1),\n        size: str | None = Form(None),\n        response_format: Literal[\"url\", \"b64_json\"] = Form(\"b64_json\"),\n        input_fidelity: Literal[\"low\", \"high\"] = Form(\"low\"),\n        quality: Literal[\"high\", \"medium\", \"low\"] = Form(\"medium\"),\n        output_format: Literal[\"png\", \"jpeg\", \"webp\"] = Form(\"png\"),\n        advanced_params: str | None = Form(None),\n    ) -> BenchImageGenerationResponse:\n        \"\"\"Handle benchmark image editing requests with generation stats.\"\"\"\n        parsed_advanced_params: AdvancedImageParams | None = None\n        if advanced_params:\n            with contextlib.suppress(Exception):\n                parsed_advanced_params = AdvancedImageParams.model_validate_json(\n                    advanced_params\n                )\n\n        command = await self._send_image_edits_command(\n            image=image,\n            prompt=prompt,\n            model=ModelId(model),\n            n=n,\n            size=normalize_image_size(size),\n            response_format=response_format,\n            input_fidelity=input_fidelity,\n            stream=False,\n            partial_images=0,\n            bench=True,\n            quality=quality,\n            output_format=output_format,\n            advanced_params=parsed_advanced_params,\n        )\n\n        return await self._collect_image_generation_with_stats(\n            request=request,\n            command_id=command.command_id,\n            num_images=n,\n            response_format=response_format,\n        )\n\n    async def claude_messages(\n        self, payload: ClaudeMessagesRequest\n    ) -> ClaudeMessagesResponse | StreamingResponse:\n        \"\"\"Claude Messages API - adapter.\"\"\"\n        task_params = claude_request_to_text_generation(payload)\n        resolved_model = await self._resolve_and_validate_text_model(\n            ModelId(task_params.model)\n        )\n        task_params = task_params.model_copy(update={\"model\": resolved_model})\n\n        command = TextGeneration(task_params=task_params)\n        await self._send(command)\n\n        if payload.stream:\n            return StreamingResponse(\n                generate_claude_stream(\n                    command.command_id,\n                    payload.model,\n                    self._token_chunk_stream(command.command_id),\n                ),\n                media_type=\"text/event-stream\",\n                headers={\n                    \"Cache-Control\": \"no-cache\",\n                    \"Connection\": \"close\",\n                    \"X-Accel-Buffering\": \"no\",\n                },\n            )\n        else:\n            return StreamingResponse(\n                collect_claude_response(\n                    command.command_id,\n                    payload.model,\n                    self._token_chunk_stream(command.command_id),\n                ),\n                media_type=\"application/json\",\n            )\n\n    async def openai_responses(\n        self, payload: ResponsesRequest\n    ) -> ResponsesResponse | StreamingResponse:\n        
\"\"\"OpenAI Responses API.\"\"\"\n        task_params = responses_request_to_text_generation(payload)\n        resolved_model = await self._resolve_and_validate_text_model(task_params.model)\n        task_params = task_params.model_copy(update={\"model\": resolved_model})\n\n        command = TextGeneration(task_params=task_params)\n        await self._send(command)\n\n        if payload.stream:\n            return StreamingResponse(\n                generate_responses_stream(\n                    command.command_id,\n                    payload.model,\n                    self._token_chunk_stream(command.command_id),\n                ),\n                media_type=\"text/event-stream\",\n                headers={\n                    \"Cache-Control\": \"no-cache\",\n                    \"Connection\": \"close\",\n                    \"X-Accel-Buffering\": \"no\",\n                },\n            )\n\n        else:\n            return StreamingResponse(\n                collect_responses_response(\n                    command.command_id,\n                    payload.model,\n                    self._token_chunk_stream(command.command_id),\n                ),\n                media_type=\"application/json\",\n            )\n\n    async def _ollama_root(self) -> JSONResponse:\n        \"\"\"Respond to HEAD / from Ollama CLI connectivity checks.\"\"\"\n        return JSONResponse(content=\"Ollama is running\")\n\n    async def ollama_chat(\n        self, request: Request\n    ) -> OllamaChatResponse | StreamingResponse:\n        \"\"\"Ollama Chat API — accepts JSON regardless of Content-Type.\"\"\"\n        body = await request.body()\n        payload = OllamaChatRequest.model_validate_json(body)\n        task_params = ollama_request_to_text_generation(payload)\n        resolved_model = await self._resolve_and_validate_text_model(\n            ModelId(task_params.model)\n        )\n        task_params = task_params.model_copy(update={\"model\": resolved_model})\n\n        command = TextGeneration(task_params=task_params)\n        await self._send(command)\n\n        if payload.stream:\n            return StreamingResponse(\n                generate_ollama_chat_stream(\n                    command.command_id,\n                    self._token_chunk_stream(command.command_id),\n                ),\n                media_type=\"application/x-ndjson\",\n                headers={\n                    \"Cache-Control\": \"no-cache\",\n                    \"Connection\": \"close\",\n                    \"X-Accel-Buffering\": \"no\",\n                },\n            )\n        else:\n            return StreamingResponse(\n                collect_ollama_chat_response(\n                    command.command_id,\n                    self._token_chunk_stream(command.command_id),\n                ),\n                media_type=\"application/json\",\n            )\n\n    async def ollama_generate(\n        self, request: Request\n    ) -> OllamaGenerateResponse | StreamingResponse:\n        \"\"\"Ollama Generate API — accepts JSON regardless of Content-Type.\"\"\"\n        body = await request.body()\n        payload = OllamaGenerateRequest.model_validate_json(body)\n        task_params = ollama_generate_request_to_text_generation(payload)\n        resolved_model = await self._resolve_and_validate_text_model(\n            ModelId(task_params.model)\n        )\n        task_params = task_params.model_copy(update={\"model\": resolved_model})\n\n        command = TextGeneration(task_params=task_params)\n        
await self._send(command)\n\n        if payload.stream:\n            return StreamingResponse(\n                generate_ollama_generate_stream(\n                    command.command_id,\n                    self._token_chunk_stream(command.command_id),\n                ),\n                media_type=\"application/x-ndjson\",\n                headers={\n                    \"Cache-Control\": \"no-cache\",\n                    \"Connection\": \"close\",\n                    \"X-Accel-Buffering\": \"no\",\n                },\n            )\n        else:\n            return StreamingResponse(\n                collect_ollama_generate_response(\n                    command.command_id,\n                    self._token_chunk_stream(command.command_id),\n                ),\n                media_type=\"application/json\",\n            )\n\n    async def ollama_tags(self) -> OllamaTagsResponse:\n        \"\"\"Returns list of models in Ollama tags format. We return the downloaded ones only.\"\"\"\n\n        def none_if_empty(value: str) -> str | None:\n            return value or None\n\n        downloaded_model_ids: set[str] = set()\n        for node_downloads in self.state.downloads.values():\n            for dl in node_downloads:\n                if isinstance(dl, DownloadCompleted):\n                    downloaded_model_ids.add(dl.shard_metadata.model_card.model_id)\n\n        cards = [\n            c for c in await get_model_cards() if c.model_id in downloaded_model_ids\n        ]\n\n        now = time.strftime(\"%Y-%m-%dT%H:%M:%SZ\", time.gmtime())\n        return OllamaTagsResponse(\n            models=[\n                OllamaModelTag(\n                    name=str(card.model_id),\n                    model=str(card.model_id),\n                    modified_at=now,\n                    size=card.storage_size.in_bytes,\n                    digest=\"sha256:000000000000\",\n                    details=OllamaModelDetails(\n                        family=none_if_empty(card.family),\n                        quantization_level=none_if_empty(card.quantization),\n                    ),\n                )\n                for card in cards\n            ]\n        )\n\n    async def ollama_show(self, request: Request) -> OllamaShowResponse:\n        \"\"\"Returns model information in Ollama show format.\"\"\"\n        body = await request.body()\n        payload = OllamaShowRequest.model_validate_json(body)\n        model_name = payload.name or payload.model\n        if not model_name:\n            raise HTTPException(status_code=400, detail=\"name or model is required\")\n        try:\n            card = await ModelCard.load(ModelId(model_name))\n        except Exception as exc:\n            raise HTTPException(\n                status_code=404, detail=f\"Model not found: {model_name}\"\n            ) from exc\n\n        return OllamaShowResponse(\n            modelfile=f\"FROM {card.model_id}\",\n            template=\"{{ .Prompt }}\",\n            details=OllamaModelDetails(\n                family=card.family or None,\n                quantization_level=card.quantization or None,\n            ),\n        )\n\n    async def ollama_ps(self) -> OllamaPsResponse:\n        \"\"\"Returns list of running models (active instances).\"\"\"\n        models: list[OllamaPsModel] = []\n        seen: set[str] = set()\n        for instance in self.state.instances.values():\n            model_id = str(instance.shard_assignments.model_id)\n            if model_id in seen:\n                continue\n            
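# Report each model once, however many instances are serving it.\n            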
seen.add(model_id)\n            models.append(\n                OllamaPsModel(\n                    name=model_id,\n                    model=model_id,\n                    size=0,\n                )\n            )\n        return OllamaPsResponse(models=models)\n\n    async def ollama_version(self) -> dict[str, str]:\n        \"\"\"Returns version information for Ollama API compatibility.\"\"\"\n        return {\"version\": \"exo v1.0\"}\n\n    def _calculate_total_available_memory(self) -> Memory:\n        \"\"\"Calculate total available memory across all nodes in bytes.\"\"\"\n        total_available = Memory()\n\n        for memory in self.state.node_memory.values():\n            total_available += memory.ram_available\n\n        return total_available\n\n    async def get_models(self, status: str | None = Query(default=None)) -> ModelList:\n        \"\"\"Returns list of available models, optionally filtered by being downloaded.\"\"\"\n        cards = await get_model_cards()\n\n        if status == \"downloaded\":\n            downloaded_model_ids: set[str] = set()\n            for node_downloads in self.state.downloads.values():\n                for dl in node_downloads:\n                    if isinstance(dl, DownloadCompleted):\n                        downloaded_model_ids.add(dl.shard_metadata.model_card.model_id)\n            cards = [c for c in cards if c.model_id in downloaded_model_ids]\n\n        return ModelList(\n            data=[\n                ModelListModel(\n                    id=card.model_id,\n                    hugging_face_id=card.model_id,\n                    name=card.model_id.short(),\n                    description=\"\",\n                    tags=[],\n                    storage_size_megabytes=card.storage_size.in_mb,\n                    supports_tensor=card.supports_tensor,\n                    tasks=[task.value for task in card.tasks],\n                    is_custom=is_custom_card(card.model_id),\n                    family=card.family,\n                    quantization=card.quantization,\n                    base_model=card.base_model,\n                    capabilities=card.capabilities,\n                )\n                for card in cards\n            ]\n        )\n\n    async def add_custom_model(self, payload: AddCustomModelParams) -> ModelListModel:\n        \"\"\"Fetch a model from HuggingFace and save as a custom model card.\"\"\"\n        try:\n            card = await ModelCard.fetch_from_hf(payload.model_id)\n        except Exception as exc:\n            raise HTTPException(\n                status_code=400, detail=f\"Failed to fetch model: {exc}\"\n            ) from exc\n\n        return ModelListModel(\n            id=card.model_id,\n            hugging_face_id=card.model_id,\n            name=card.model_id.short(),\n            description=\"\",\n            tags=[],\n            storage_size_megabytes=int(card.storage_size.in_mb),\n            supports_tensor=card.supports_tensor,\n            tasks=[task.value for task in card.tasks],\n            is_custom=True,\n        )\n\n    async def delete_custom_model(self, model_id: ModelId) -> JSONResponse:\n        \"\"\"Delete a user-added custom model card.\"\"\"\n        deleted = await delete_custom_card(model_id)\n        if not deleted:\n            raise HTTPException(status_code=404, detail=\"Custom model card not found\")\n        return JSONResponse(\n            {\"message\": \"Model card deleted\", \"model_id\": str(model_id)}\n        )\n\n    async def search_models(\n        
self, query: str = \"\", limit: int = 20\n    ) -> list[HuggingFaceSearchResult]:\n        \"\"\"Search HuggingFace Hub — tries mlx-community first, falls back to all of HuggingFace.\"\"\"\n        from huggingface_hub import ModelInfo, list_models\n\n        def _to_results(models: Iterable[ModelInfo]) -> list[HuggingFaceSearchResult]:\n            return [\n                HuggingFaceSearchResult(\n                    id=m.id,\n                    author=m.author or \"\",\n                    downloads=m.downloads or 0,\n                    likes=m.likes or 0,\n                    last_modified=str(m.last_modified or \"\"),\n                    tags=list(m.tags or []),\n                )\n                for m in models\n            ]\n\n        # Search mlx-community first\n        mlx_results = _to_results(\n            list_models(\n                search=query or None,\n                author=\"mlx-community\",\n                sort=\"downloads\",\n                limit=limit,\n            )\n        )\n        if mlx_results:\n            return mlx_results\n\n        # Fall back to searching all of HuggingFace\n        return _to_results(\n            list_models(\n                search=query or None,\n                sort=\"downloads\",\n                limit=limit,\n            )\n        )\n\n    async def run(self):\n        shutdown_ev = anyio.Event()\n\n        try:\n            async with self._tg as tg:\n                logger.info(\"Starting API\")\n                tg.start_soon(self._apply_state)\n                tg.start_soon(self._pause_on_new_election)\n                tg.start_soon(self._cleanup_expired_images)\n                print_startup_banner(self.port)\n                tg.start_soon(self.run_api, shutdown_ev)\n                try:\n                    await anyio.sleep_forever()\n                finally:\n                    with anyio.CancelScope(shield=True):\n                        shutdown_ev.set()\n        finally:\n            self._event_log.close()\n            self.command_sender.close()\n            self.event_receiver.close()\n\n    async def run_api(self, ev: anyio.Event):\n        cfg = Config()\n        cfg.bind = [f\"0.0.0.0:{self.port}\"]\n        # nb: shared.logging needs updating if any of this changes\n        cfg.accesslog = None\n        cfg.errorlog = \"-\"\n        cfg.logger_class = InterceptLogger\n        with anyio.CancelScope(shield=True):\n            await serve(\n                cast(ASGIFramework, self.app),\n                cfg,\n                shutdown_trigger=ev.wait,\n            )\n\n    async def _apply_state(self):\n        with self.event_receiver as events:\n            async for i_event in events:\n                self._event_log.append(i_event.event)\n                self.state = apply(self.state, i_event)\n                event = i_event.event\n\n                if isinstance(event, ChunkGenerated):\n                    if queue := self._image_generation_queues.get(\n                        event.command_id, None\n                    ):\n                        assert isinstance(event.chunk, ImageChunk)\n                        try:\n                            await queue.send(event.chunk)\n                        except BrokenResourceError:\n                            self._image_generation_queues.pop(event.command_id, None)\n                    if queue := self._text_generation_queues.get(\n                        event.command_id, None\n                    ):\n                        assert not 
isinstance(event.chunk, ImageChunk)\n                        try:\n                            await queue.send(event.chunk)\n                        except BrokenResourceError:\n                            self._text_generation_queues.pop(event.command_id, None)\n                if isinstance(event, TracesMerged):\n                    self._save_merged_trace(event)\n\n    def _save_merged_trace(self, event: TracesMerged) -> None:\n        traces = [\n            TraceEvent(\n                name=t.name,\n                start_us=t.start_us,\n                duration_us=t.duration_us,\n                rank=t.rank,\n                category=t.category,\n            )\n            for t in event.traces\n        ]\n        output_path = EXO_TRACING_CACHE_DIR / f\"trace_{event.task_id}.json\"\n        export_trace(traces, output_path)\n        logger.debug(f\"Saved merged trace to {output_path}\")\n\n    async def _pause_on_new_election(self):\n        with self.election_receiver as ems:\n            async for message in ems:\n                if message.clock > self.last_completed_election:\n                    self.paused = True\n\n    async def _cleanup_expired_images(self):\n        \"\"\"Periodically clean up expired images from the store.\"\"\"\n        cleanup_interval_seconds = 300  # 5 minutes\n        while True:\n            await anyio.sleep(cleanup_interval_seconds)\n            removed = self._image_store.cleanup_expired()\n            if removed > 0:\n                logger.debug(f\"Cleaned up {removed} expired images\")\n\n    async def _send(self, command: Command):\n        while self.paused:\n            await self.paused_ev.wait()\n        await self.command_sender.send(\n            ForwarderCommand(origin=self._system_id, command=command)\n        )\n\n    async def _send_download(self, command: DownloadCommand):\n        await self.download_command_sender.send(\n            ForwarderDownloadCommand(origin=self._system_id, command=command)\n        )\n\n    async def start_download(\n        self, payload: StartDownloadParams\n    ) -> StartDownloadResponse:\n        command = StartDownload(\n            target_node_id=payload.target_node_id,\n            shard_metadata=payload.shard_metadata,\n        )\n        await self._send_download(command)\n        return StartDownloadResponse(command_id=command.command_id)\n\n    async def delete_download(\n        self, node_id: NodeId, model_id: ModelId\n    ) -> DeleteDownloadResponse:\n        command = DeleteDownload(\n            target_node_id=node_id,\n            model_id=ModelId(model_id),\n        )\n        await self._send_download(command)\n        return DeleteDownloadResponse(command_id=command.command_id)\n\n    @staticmethod\n    def _get_trace_path(task_id: str) -> Path:\n        trace_path = EXO_TRACING_CACHE_DIR / f\"trace_{task_id}.json\"\n        if not trace_path.resolve().is_relative_to(EXO_TRACING_CACHE_DIR.resolve()):\n            raise HTTPException(status_code=400, detail=f\"Invalid task ID: {task_id}\")\n        return trace_path\n\n    async def list_traces(self) -> TraceListResponse:\n        traces: list[TraceListItem] = []\n\n        for trace_file in sorted(\n            EXO_TRACING_CACHE_DIR.glob(\"trace_*.json\"),\n            key=lambda p: p.stat().st_mtime,\n            reverse=True,\n        ):\n            # Extract task_id from filename (trace_{task_id}.json)\n            task_id = trace_file.stem.removeprefix(\"trace_\")\n            stat = trace_file.stat()\n            created_at = 
datetime.fromtimestamp(\n                stat.st_mtime, tz=timezone.utc\n            ).isoformat()\n            traces.append(\n                TraceListItem(\n                    task_id=task_id,\n                    created_at=created_at,\n                    file_size=stat.st_size,\n                )\n            )\n\n        return TraceListResponse(traces=traces)\n\n    async def get_trace(self, task_id: str) -> TraceResponse:\n        trace_path = self._get_trace_path(task_id)\n\n        if not trace_path.exists():\n            raise HTTPException(status_code=404, detail=f\"Trace not found: {task_id}\")\n\n        trace_events = load_trace_file(trace_path)\n\n        return TraceResponse(\n            task_id=task_id,\n            traces=[\n                TraceEventResponse(\n                    name=event.name,\n                    start_us=event.start_us,\n                    duration_us=event.duration_us,\n                    rank=event.rank,\n                    category=event.category,\n                )\n                for event in trace_events\n            ],\n        )\n\n    async def get_trace_stats(self, task_id: str) -> TraceStatsResponse:\n        trace_path = self._get_trace_path(task_id)\n\n        if not trace_path.exists():\n            raise HTTPException(status_code=404, detail=f\"Trace not found: {task_id}\")\n\n        trace_events = load_trace_file(trace_path)\n        stats = compute_stats(trace_events)\n\n        return TraceStatsResponse(\n            task_id=task_id,\n            total_wall_time_us=stats.total_wall_time_us,\n            by_category={\n                category: TraceCategoryStats(\n                    total_us=cat_stats.total_us,\n                    count=cat_stats.count,\n                    min_us=cat_stats.min_us,\n                    max_us=cat_stats.max_us,\n                    avg_us=cat_stats.avg_us,\n                )\n                for category, cat_stats in stats.by_category.items()\n            },\n            by_rank={\n                rank: TraceRankStats(\n                    by_category={\n                        category: TraceCategoryStats(\n                            total_us=cat_stats.total_us,\n                            count=cat_stats.count,\n                            min_us=cat_stats.min_us,\n                            max_us=cat_stats.max_us,\n                            avg_us=cat_stats.avg_us,\n                        )\n                        for category, cat_stats in rank_stats.items()\n                    }\n                )\n                for rank, rank_stats in stats.by_rank.items()\n            },\n        )\n\n    async def get_trace_raw(self, task_id: str) -> FileResponse:\n        trace_path = self._get_trace_path(task_id)\n\n        if not trace_path.exists():\n            raise HTTPException(status_code=404, detail=f\"Trace not found: {task_id}\")\n\n        return FileResponse(\n            path=trace_path,\n            media_type=\"application/json\",\n            filename=f\"trace_{task_id}.json\",\n        )\n\n    async def delete_traces(self, request: DeleteTracesRequest) -> DeleteTracesResponse:\n        deleted: list[str] = []\n        not_found: list[str] = []\n        for task_id in request.task_ids:\n            trace_path = self._get_trace_path(task_id)\n            if trace_path.exists():\n                trace_path.unlink()\n                deleted.append(task_id)\n            else:\n                not_found.append(task_id)\n        return DeleteTracesResponse(deleted=deleted, 
not_found=not_found)\n\n    async def get_onboarding(self) -> JSONResponse:\n        return JSONResponse({\"completed\": ONBOARDING_COMPLETE_FILE.exists()})\n\n    async def complete_onboarding(self) -> JSONResponse:\n        ONBOARDING_COMPLETE_FILE.parent.mkdir(parents=True, exist_ok=True)\n        ONBOARDING_COMPLETE_FILE.write_text(\"true\")\n        return JSONResponse({\"completed\": True})\n"
  },
  {
    "path": "src/exo/api/tests/test_api_error_handling.py",
    "content": "# pyright: reportUnusedFunction=false, reportAny=false\nfrom typing import Any\n\nfrom fastapi import FastAPI, HTTPException\nfrom fastapi.testclient import TestClient\n\nfrom exo.api.main import API\n\n\ndef test_http_exception_handler_formats_openai_style() -> None:\n    \"\"\"Test that HTTPException is converted to OpenAI-style error format.\"\"\"\n\n    app = FastAPI()\n\n    # Setup exception handler\n    api = object.__new__(API)\n    api.app = app\n    api._setup_exception_handlers()  # pyright: ignore[reportPrivateUsage]\n\n    # Add test routes that raise HTTPException\n    @app.get(\"/test-error\")\n    async def _test_error() -> None:\n        raise HTTPException(status_code=500, detail=\"Test error message\")\n\n    @app.get(\"/test-not-found\")\n    async def _test_not_found() -> None:\n        raise HTTPException(status_code=404, detail=\"Resource not found\")\n\n    client = TestClient(app)\n\n    # Test 500 error\n    response = client.get(\"/test-error\")\n    assert response.status_code == 500\n    data: dict[str, Any] = response.json()\n    assert \"error\" in data\n    assert data[\"error\"][\"message\"] == \"Test error message\"\n    assert data[\"error\"][\"type\"] == \"Internal Server Error\"\n    assert data[\"error\"][\"code\"] == 500\n\n    # Test 404 error\n    response = client.get(\"/test-not-found\")\n    assert response.status_code == 404\n    data = response.json()\n    assert \"error\" in data\n    assert data[\"error\"][\"message\"] == \"Resource not found\"\n    assert data[\"error\"][\"type\"] == \"Not Found\"\n    assert data[\"error\"][\"code\"] == 404\n"
  },
  {
    "path": "src/exo/api/tests/test_cancel_command.py",
    "content": "# pyright: reportUnusedFunction=false, reportAny=false\nfrom typing import Any\nfrom unittest.mock import AsyncMock, MagicMock\n\nfrom fastapi import FastAPI\nfrom fastapi.testclient import TestClient\n\nfrom exo.api.main import API\nfrom exo.shared.types.common import CommandId\n\n\ndef _make_api() -> Any:\n    \"\"\"Create a minimal API instance with cancel route and error handler.\"\"\"\n\n    app = FastAPI()\n    api = object.__new__(API)\n    api.app = app\n    api._text_generation_queues = {}  # pyright: ignore[reportPrivateUsage]\n    api._image_generation_queues = {}  # pyright: ignore[reportPrivateUsage]\n    api._send = AsyncMock()  # pyright: ignore[reportPrivateUsage]\n    api._setup_exception_handlers()  # pyright: ignore[reportPrivateUsage]\n    app.post(\"/v1/cancel/{command_id}\")(api.cancel_command)\n    return api\n\n\ndef test_cancel_nonexistent_command_returns_404() -> None:\n    \"\"\"Cancel for an unknown command_id returns 404 in OpenAI error format.\"\"\"\n    api = _make_api()\n    client = TestClient(api.app)\n\n    response = client.post(\"/v1/cancel/nonexistent-id\")\n    assert response.status_code == 404\n    data: dict[str, Any] = response.json()\n    assert \"error\" in data\n    assert data[\"error\"][\"message\"] == \"Command not found or already completed\"\n    assert data[\"error\"][\"type\"] == \"Not Found\"\n    assert data[\"error\"][\"code\"] == 404\n\n\ndef test_cancel_active_text_generation() -> None:\n    \"\"\"Cancel an active text generation command: returns 200, sender.close() called.\"\"\"\n    api = _make_api()\n    client = TestClient(api.app)\n\n    cid = CommandId(\"text-cmd-123\")\n    sender = MagicMock()\n    api._text_generation_queues[cid] = sender\n\n    response = client.post(f\"/v1/cancel/{cid}\")\n    assert response.status_code == 200\n    data: dict[str, Any] = response.json()\n    assert data[\"message\"] == \"Command cancelled.\"\n    assert data[\"command_id\"] == str(cid)\n    sender.close.assert_called_once()\n    api._send.assert_called_once()\n    task_cancelled = api._send.call_args[0][0]\n    assert task_cancelled.cancelled_command_id == cid\n\n\ndef test_cancel_active_image_generation() -> None:\n    \"\"\"Cancel an active image generation command: returns 200, sender.close() called.\"\"\"\n    api = _make_api()\n    client = TestClient(api.app)\n\n    cid = CommandId(\"img-cmd-456\")\n    sender = MagicMock()\n    api._image_generation_queues[cid] = sender\n\n    response = client.post(f\"/v1/cancel/{cid}\")\n    assert response.status_code == 200\n    data: dict[str, Any] = response.json()\n    assert data[\"message\"] == \"Command cancelled.\"\n    assert data[\"command_id\"] == str(cid)\n    sender.close.assert_called_once()\n    api._send.assert_called_once()\n    task_cancelled = api._send.call_args[0][0]\n    assert task_cancelled.cancelled_command_id == cid\n"
  },
  {
    "path": "src/exo/api/tests/test_claude_api.py",
    "content": "\"\"\"Tests for Claude Messages API conversion functions and types.\"\"\"\n\nimport pydantic\nimport pytest\n\nfrom exo.api.adapters.claude import (\n    claude_request_to_text_generation,\n    finish_reason_to_claude_stop_reason,\n)\nfrom exo.api.types.claude_api import (\n    ClaudeMessage,\n    ClaudeMessagesRequest,\n    ClaudeTextBlock,\n)\nfrom exo.shared.types.common import ModelId\n\n\nclass TestFinishReasonToClaudeStopReason:\n    \"\"\"Tests for finish_reason to Claude stop_reason mapping.\"\"\"\n\n    def test_stop_maps_to_end_turn(self):\n        assert finish_reason_to_claude_stop_reason(\"stop\") == \"end_turn\"\n\n    def test_length_maps_to_max_tokens(self):\n        assert finish_reason_to_claude_stop_reason(\"length\") == \"max_tokens\"\n\n    def test_tool_calls_maps_to_tool_use(self):\n        assert finish_reason_to_claude_stop_reason(\"tool_calls\") == \"tool_use\"\n\n    def test_function_call_maps_to_tool_use(self):\n        assert finish_reason_to_claude_stop_reason(\"function_call\") == \"tool_use\"\n\n    def test_content_filter_maps_to_end_turn(self):\n        assert finish_reason_to_claude_stop_reason(\"content_filter\") == \"end_turn\"\n\n    def test_none_returns_none(self):\n        assert finish_reason_to_claude_stop_reason(None) is None\n\n\nclass TestClaudeRequestToInternal:\n    \"\"\"Tests for converting Claude Messages API requests to TextGenerationTaskParams.\"\"\"\n\n    def test_basic_request_conversion(self):\n        request = ClaudeMessagesRequest(\n            model=ModelId(\"claude-3-opus\"),\n            max_tokens=100,\n            messages=[\n                ClaudeMessage(role=\"user\", content=\"Hello\"),\n            ],\n        )\n        params = claude_request_to_text_generation(request)\n\n        assert params.model == \"claude-3-opus\"\n        assert params.max_output_tokens == 100\n        assert isinstance(params.input, list)\n        assert len(params.input) == 1\n        assert params.input[0].role == \"user\"\n        assert params.input[0].content == \"Hello\"\n        assert params.instructions is None\n\n    def test_request_with_system_string(self):\n        request = ClaudeMessagesRequest(\n            model=ModelId(\"claude-3-opus\"),\n            max_tokens=100,\n            system=\"You are a helpful assistant.\",\n            messages=[\n                ClaudeMessage(role=\"user\", content=\"Hello\"),\n            ],\n        )\n        params = claude_request_to_text_generation(request)\n\n        assert params.instructions == \"You are a helpful assistant.\"\n        assert isinstance(params.input, list)\n        assert len(params.input) == 1\n        assert params.input[0].role == \"user\"\n        assert params.input[0].content == \"Hello\"\n\n    def test_request_with_system_text_blocks(self):\n        request = ClaudeMessagesRequest(\n            model=ModelId(\"claude-3-opus\"),\n            max_tokens=100,\n            system=[\n                ClaudeTextBlock(text=\"You are helpful. \"),\n                ClaudeTextBlock(text=\"Be concise.\"),\n            ],\n            messages=[\n                ClaudeMessage(role=\"user\", content=\"Hello\"),\n            ],\n        )\n        params = claude_request_to_text_generation(request)\n\n        assert params.instructions == \"You are helpful. 
Be concise.\"\n        assert isinstance(params.input, list)\n        assert len(params.input) == 1\n\n    def test_request_with_content_blocks(self):\n        request = ClaudeMessagesRequest(\n            model=ModelId(\"claude-3-opus\"),\n            max_tokens=100,\n            messages=[\n                ClaudeMessage(\n                    role=\"user\",\n                    content=[\n                        ClaudeTextBlock(text=\"First part. \"),\n                        ClaudeTextBlock(text=\"Second part.\"),\n                    ],\n                ),\n            ],\n        )\n        params = claude_request_to_text_generation(request)\n\n        assert isinstance(params.input, list)\n        assert len(params.input) == 1\n        assert params.input[0].content == \"First part. Second part.\"\n\n    def test_request_with_multi_turn_conversation(self):\n        request = ClaudeMessagesRequest(\n            model=ModelId(\"claude-3-opus\"),\n            max_tokens=100,\n            messages=[\n                ClaudeMessage(role=\"user\", content=\"Hello\"),\n                ClaudeMessage(role=\"assistant\", content=\"Hi there!\"),\n                ClaudeMessage(role=\"user\", content=\"How are you?\"),\n            ],\n        )\n        params = claude_request_to_text_generation(request)\n\n        assert isinstance(params.input, list)\n        assert len(params.input) == 3\n        assert params.input[0].role == \"user\"\n        assert params.input[1].role == \"assistant\"\n        assert params.input[2].role == \"user\"\n\n    def test_request_with_optional_parameters(self):\n        request = ClaudeMessagesRequest(\n            model=ModelId(\"claude-3-opus\"),\n            max_tokens=100,\n            messages=[ClaudeMessage(role=\"user\", content=\"Hello\")],\n            temperature=0.7,\n            top_p=0.9,\n            top_k=40,\n            stop_sequences=[\"STOP\", \"END\"],\n            stream=True,\n        )\n        params = claude_request_to_text_generation(request)\n\n        assert params.temperature == 0.7\n        assert params.top_p == 0.9\n        assert params.top_k == 40\n        assert params.stop == [\"STOP\", \"END\"]\n        assert params.stream is True\n\n\nclass TestClaudeMessagesRequestValidation:\n    \"\"\"Tests for Claude Messages API request validation.\"\"\"\n\n    def test_request_requires_model(self):\n        with pytest.raises(pydantic.ValidationError):\n            ClaudeMessagesRequest.model_validate(\n                {\n                    \"max_tokens\": 100,\n                    \"messages\": [{\"role\": \"user\", \"content\": \"Hello\"}],\n                }\n            )\n\n    def test_request_requires_max_tokens(self):\n        with pytest.raises(pydantic.ValidationError):\n            ClaudeMessagesRequest.model_validate(\n                {\n                    \"model\": \"claude-3-opus\",\n                    \"messages\": [{\"role\": \"user\", \"content\": \"Hello\"}],\n                }\n            )\n\n    def test_request_requires_messages(self):\n        with pytest.raises(pydantic.ValidationError):\n            ClaudeMessagesRequest.model_validate(\n                {\n                    \"model\": \"claude-3-opus\",\n                    \"max_tokens\": 100,\n                }\n            )\n"
  },
  {
    "path": "src/exo/api/tests/test_claude_tool_use.py",
    "content": "\"\"\"Tests for Claude Messages API tool_use support in the adapter.\"\"\"\n\nimport json\nfrom collections.abc import AsyncGenerator\nfrom typing import Any, cast\n\nfrom exo.api.adapters.claude import (\n    ClaudeMessagesResponse,\n    collect_claude_response,\n    generate_claude_stream,\n)\nfrom exo.api.types import ToolCallItem\nfrom exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk\nfrom exo.shared.types.common import CommandId, ModelId\n\n\nasync def _chunks_to_stream(\n    chunks: list[ErrorChunk | ToolCallChunk | TokenChunk],\n) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:\n    for chunk in chunks:\n        yield chunk\n\n\nasync def _collect_response(\n    command_id: CommandId,\n    model: str,\n    chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],\n) -> ClaudeMessagesResponse:\n    \"\"\"Helper to consume the async generator and parse the JSON response.\"\"\"\n    parts: list[str] = []\n    async for part in collect_claude_response(command_id, model, chunk_stream):\n        parts.append(part)\n    return ClaudeMessagesResponse.model_validate_json(\"\".join(parts))\n\n\nMODEL = ModelId(\"test-model\")\nCOMMAND_ID = CommandId(\"cmd_test123\")\n\n\ndef _parse_sse_events(events: list[str]) -> list[dict[str, Any]]:\n    \"\"\"Parse SSE event strings into JSON dicts.\"\"\"\n    parsed: list[dict[str, Any]] = []\n    for event_str in events:\n        for line in event_str.strip().split(\"\\n\"):\n            if line.startswith(\"data: \"):\n                parsed.append(cast(dict[str, Any], json.loads(line[6:])))\n    return parsed\n\n\nclass TestCollectClaudeResponseToolUse:\n    \"\"\"Tests for non-streaming tool_use response collection.\"\"\"\n\n    async def test_tool_call_chunk_produces_tool_use_blocks(self):\n        chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [\n            ToolCallChunk(\n                model=MODEL,\n                usage=None,\n                tool_calls=[\n                    ToolCallItem(\n                        name=\"get_weather\",\n                        arguments='{\"location\": \"San Francisco\"}',\n                    )\n                ],\n            ),\n        ]\n        response = await _collect_response(\n            COMMAND_ID, \"test-model\", _chunks_to_stream(chunks)\n        )\n\n        assert response.stop_reason == \"tool_use\"\n        tool_blocks = [b for b in response.content if b.type == \"tool_use\"]\n        assert len(tool_blocks) == 1\n        block = tool_blocks[0]\n        assert block.type == \"tool_use\"\n        assert block.name == \"get_weather\"\n        assert block.input == {\"location\": \"San Francisco\"}\n        assert block.id.startswith(\"toolu_\")\n\n    async def test_multiple_tool_calls(self):\n        chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [\n            ToolCallChunk(\n                model=MODEL,\n                usage=None,\n                tool_calls=[\n                    ToolCallItem(\n                        name=\"get_weather\",\n                        arguments='{\"location\": \"SF\"}',\n                    ),\n                    ToolCallItem(\n                        name=\"get_time\",\n                        arguments='{\"timezone\": \"PST\"}',\n                    ),\n                ],\n            ),\n        ]\n        response = await _collect_response(\n            COMMAND_ID, \"test-model\", _chunks_to_stream(chunks)\n        )\n\n        assert response.stop_reason == 
\"tool_use\"\n        tool_blocks = [b for b in response.content if b.type == \"tool_use\"]\n        assert len(tool_blocks) == 2\n        assert tool_blocks[0].name == \"get_weather\"\n        assert tool_blocks[1].name == \"get_time\"\n\n    async def test_mixed_text_and_tool_use(self):\n        chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [\n            TokenChunk(model=MODEL, text=\"Let me check \", token_id=1, usage=None),\n            TokenChunk(model=MODEL, text=\"the weather.\", token_id=2, usage=None),\n            ToolCallChunk(\n                model=MODEL,\n                usage=None,\n                tool_calls=[\n                    ToolCallItem(\n                        name=\"get_weather\",\n                        arguments='{\"location\": \"NYC\"}',\n                    )\n                ],\n            ),\n        ]\n        response = await _collect_response(\n            COMMAND_ID, \"test-model\", _chunks_to_stream(chunks)\n        )\n\n        assert response.stop_reason == \"tool_use\"\n        text_blocks = [b for b in response.content if b.type == \"text\"]\n        tool_blocks = [b for b in response.content if b.type == \"tool_use\"]\n        assert len(text_blocks) == 1\n        assert text_blocks[0].text == \"Let me check the weather.\"\n        assert len(tool_blocks) == 1\n        assert tool_blocks[0].name == \"get_weather\"\n\n    async def test_no_content_produces_empty_text_block(self):\n        chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = []\n        response = await _collect_response(\n            COMMAND_ID, \"test-model\", _chunks_to_stream(chunks)\n        )\n        assert len(response.content) == 1\n        assert response.content[0].type == \"text\"\n\n\nclass TestGenerateClaudeStreamToolUse:\n    \"\"\"Tests for streaming tool_use event generation.\"\"\"\n\n    async def test_tool_call_emits_tool_use_events(self):\n        chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [\n            ToolCallChunk(\n                model=MODEL,\n                usage=None,\n                tool_calls=[\n                    ToolCallItem(\n                        name=\"get_weather\",\n                        arguments='{\"location\": \"SF\"}',\n                    )\n                ],\n            ),\n        ]\n        events: list[str] = []\n        async for event in generate_claude_stream(\n            COMMAND_ID, \"test-model\", _chunks_to_stream(chunks)\n        ):\n            events.append(event)\n\n        parsed = _parse_sse_events(events)\n\n        # Find tool_use content_block_start\n        tool_starts = [\n            e\n            for e in parsed\n            if e.get(\"type\") == \"content_block_start\"\n            and cast(dict[str, Any], e.get(\"content_block\", {})).get(\"type\")\n            == \"tool_use\"\n        ]\n        assert len(tool_starts) == 1\n        content_block = cast(dict[str, Any], tool_starts[0][\"content_block\"])\n        assert content_block[\"name\"] == \"get_weather\"\n        assert content_block[\"input\"] == {}\n        assert cast(str, content_block[\"id\"]).startswith(\"toolu_\")\n\n        # Find input_json_delta\n        json_deltas = [\n            e\n            for e in parsed\n            if e.get(\"type\") == \"content_block_delta\"\n            and cast(dict[str, Any], e.get(\"delta\", {})).get(\"type\")\n            == \"input_json_delta\"\n        ]\n        assert len(json_deltas) == 1\n        delta = cast(dict[str, Any], json_deltas[0][\"delta\"])\n        assert 
json.loads(cast(str, delta[\"partial_json\"])) == {\"location\": \"SF\"}\n\n        # Find message_delta with tool_use stop reason\n        msg_deltas = [e for e in parsed if e.get(\"type\") == \"message_delta\"]\n        assert len(msg_deltas) == 1\n        assert cast(dict[str, Any], msg_deltas[0][\"delta\"])[\"stop_reason\"] == \"tool_use\"\n\n    async def test_streaming_mixed_text_and_tool_use(self):\n        chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [\n            TokenChunk(model=MODEL, text=\"Hello \", token_id=1, usage=None),\n            ToolCallChunk(\n                model=MODEL,\n                usage=None,\n                tool_calls=[\n                    ToolCallItem(\n                        name=\"search\",\n                        arguments='{\"query\": \"test\"}',\n                    )\n                ],\n            ),\n        ]\n        events: list[str] = []\n        async for event in generate_claude_stream(\n            COMMAND_ID, \"test-model\", _chunks_to_stream(chunks)\n        ):\n            events.append(event)\n\n        parsed = _parse_sse_events(events)\n\n        # Should have text delta at index 0\n        text_deltas = [\n            e\n            for e in parsed\n            if e.get(\"type\") == \"content_block_delta\"\n            and cast(dict[str, Any], e.get(\"delta\", {})).get(\"type\") == \"text_delta\"\n        ]\n        assert len(text_deltas) == 1\n        assert text_deltas[0][\"index\"] == 0\n        assert cast(dict[str, Any], text_deltas[0][\"delta\"])[\"text\"] == \"Hello \"\n\n        # Tool block at index 1\n        tool_starts = [\n            e\n            for e in parsed\n            if e.get(\"type\") == \"content_block_start\"\n            and cast(dict[str, Any], e.get(\"content_block\", {})).get(\"type\")\n            == \"tool_use\"\n        ]\n        assert len(tool_starts) == 1\n        assert tool_starts[0][\"index\"] == 1\n\n        # Stop reason should be tool_use\n        msg_deltas = [e for e in parsed if e.get(\"type\") == \"message_delta\"]\n        assert cast(dict[str, Any], msg_deltas[0][\"delta\"])[\"stop_reason\"] == \"tool_use\"\n\n    async def test_streaming_tool_block_stop_events(self):\n        chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [\n            ToolCallChunk(\n                model=MODEL,\n                usage=None,\n                tool_calls=[\n                    ToolCallItem(name=\"fn1\", arguments=\"{}\"),\n                    ToolCallItem(name=\"fn2\", arguments='{\"a\": 1}'),\n                ],\n            ),\n        ]\n        events: list[str] = []\n        async for event in generate_claude_stream(\n            COMMAND_ID, \"test-model\", _chunks_to_stream(chunks)\n        ):\n            events.append(event)\n\n        parsed = _parse_sse_events(events)\n\n        # Two tool block starts (at indices 0 and 1 — no text block when only tools)\n        tool_starts = [\n            e\n            for e in parsed\n            if e.get(\"type\") == \"content_block_start\"\n            and cast(dict[str, Any], e.get(\"content_block\", {})).get(\"type\")\n            == \"tool_use\"\n        ]\n        assert len(tool_starts) == 2\n        assert tool_starts[0][\"index\"] == 0\n        assert tool_starts[1][\"index\"] == 1\n\n        # Two tool block stops (at indices 0 and 1)\n        block_stops = [e for e in parsed if e.get(\"type\") == \"content_block_stop\"]\n        stop_indices = [e[\"index\"] for e in block_stops]\n        assert 0 in stop_indices\n      
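  # each started block must receive a matching stop event, so assert membership\n      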
  assert 1 in stop_indices\n"
  },
  {
    "path": "src/exo/api/tests/test_openai_responses_api.py",
    "content": "\"\"\"Tests for OpenAI Responses API wire types.\n\nResponsesRequest is the API wire type for the Responses endpoint.\nThe responses adapter converts it to TextGenerationTaskParams for the pipeline.\n\"\"\"\n\nimport pydantic\nimport pytest\n\nfrom exo.api.types.openai_responses import (\n    ResponseInputMessage,\n    ResponsesRequest,\n)\nfrom exo.shared.types.common import ModelId\n\n\nclass TestResponsesRequestValidation:\n    \"\"\"Tests for OpenAI Responses API request validation.\"\"\"\n\n    def test_request_requires_model(self):\n        with pytest.raises(pydantic.ValidationError):\n            ResponsesRequest.model_validate(\n                {\n                    \"input\": \"Hello\",\n                }\n            )\n\n    def test_request_requires_input(self):\n        with pytest.raises(pydantic.ValidationError):\n            ResponsesRequest.model_validate(\n                {\n                    \"model\": \"gpt-4o\",\n                }\n            )\n\n    def test_request_accepts_string_input(self):\n        request = ResponsesRequest(\n            model=ModelId(\"gpt-4o\"),\n            input=\"Hello\",\n        )\n        assert request.input == \"Hello\"\n\n    def test_request_accepts_message_array_input(self):\n        request = ResponsesRequest(\n            model=ModelId(\"gpt-4o\"),\n            input=[ResponseInputMessage(role=\"user\", content=\"Hello\")],\n        )\n        assert len(request.input) == 1\n"
  },
  {
    "path": "src/exo/api/types/__init__.py",
    "content": "from .api import AddCustomModelParams as AddCustomModelParams\nfrom .api import AdvancedImageParams as AdvancedImageParams\nfrom .api import BenchChatCompletionRequest as BenchChatCompletionRequest\nfrom .api import BenchChatCompletionResponse as BenchChatCompletionResponse\nfrom .api import BenchImageGenerationResponse as BenchImageGenerationResponse\nfrom .api import BenchImageGenerationTaskParams as BenchImageGenerationTaskParams\nfrom .api import CancelCommandResponse as CancelCommandResponse\nfrom .api import ChatCompletionChoice as ChatCompletionChoice\nfrom .api import ChatCompletionMessage as ChatCompletionMessage\nfrom .api import ChatCompletionMessageText as ChatCompletionMessageText\nfrom .api import ChatCompletionRequest as ChatCompletionRequest\nfrom .api import ChatCompletionResponse as ChatCompletionResponse\nfrom .api import CompletionTokensDetails as CompletionTokensDetails\nfrom .api import CreateInstanceParams as CreateInstanceParams\nfrom .api import CreateInstanceResponse as CreateInstanceResponse\nfrom .api import DeleteDownloadResponse as DeleteDownloadResponse\nfrom .api import DeleteInstanceResponse as DeleteInstanceResponse\nfrom .api import DeleteTracesRequest as DeleteTracesRequest\nfrom .api import DeleteTracesResponse as DeleteTracesResponse\nfrom .api import ErrorInfo as ErrorInfo\nfrom .api import ErrorResponse as ErrorResponse\nfrom .api import FinishReason as FinishReason\nfrom .api import GenerationStats as GenerationStats\nfrom .api import HuggingFaceSearchResult as HuggingFaceSearchResult\nfrom .api import ImageData as ImageData\nfrom .api import ImageEditsTaskParams as ImageEditsTaskParams\nfrom .api import ImageGenerationResponse as ImageGenerationResponse\nfrom .api import ImageGenerationStats as ImageGenerationStats\nfrom .api import ImageGenerationTaskParams as ImageGenerationTaskParams\nfrom .api import ImageListItem as ImageListItem\nfrom .api import ImageListResponse as ImageListResponse\nfrom .api import ImageSize as ImageSize\nfrom .api import Logprobs as Logprobs\nfrom .api import LogprobsContentItem as LogprobsContentItem\nfrom .api import ModelList as ModelList\nfrom .api import ModelListModel as ModelListModel\nfrom .api import NodePowerStats as NodePowerStats\nfrom .api import PlaceInstanceParams as PlaceInstanceParams\nfrom .api import PlacementPreview as PlacementPreview\nfrom .api import PlacementPreviewResponse as PlacementPreviewResponse\nfrom .api import PowerUsage as PowerUsage\nfrom .api import PromptTokensDetails as PromptTokensDetails\nfrom .api import StartDownloadParams as StartDownloadParams\nfrom .api import StartDownloadResponse as StartDownloadResponse\nfrom .api import StreamingChoiceResponse as StreamingChoiceResponse\nfrom .api import ToolCall as ToolCall\nfrom .api import ToolCallItem as ToolCallItem\nfrom .api import TopLogprobItem as TopLogprobItem\nfrom .api import TraceCategoryStats as TraceCategoryStats\nfrom .api import TraceEventResponse as TraceEventResponse\nfrom .api import TraceListItem as TraceListItem\nfrom .api import TraceListResponse as TraceListResponse\nfrom .api import TraceRankStats as TraceRankStats\nfrom .api import TraceResponse as TraceResponse\nfrom .api import TraceStatsResponse as TraceStatsResponse\nfrom .api import Usage as Usage\nfrom .api import normalize_image_size as normalize_image_size\n"
  },
  {
    "path": "src/exo/api/types/api.py",
    "content": "import time\nfrom collections.abc import Generator\nfrom typing import Annotated, Any, Literal, get_args\nfrom uuid import uuid4\n\nfrom pydantic import BaseModel, Field, field_validator\n\nfrom exo.shared.models.model_cards import ModelCard, ModelId\nfrom exo.shared.types.common import CommandId, NodeId\nfrom exo.shared.types.memory import Memory\nfrom exo.shared.types.text_generation import ReasoningEffort\nfrom exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta\nfrom exo.shared.types.worker.shards import Sharding, ShardMetadata\nfrom exo.utils.pydantic_ext import CamelCaseModel\n\nFinishReason = Literal[\n    \"stop\", \"length\", \"tool_calls\", \"content_filter\", \"function_call\", \"error\"\n]\n\n\nclass ErrorInfo(BaseModel):\n    message: str\n    type: str\n    param: str | None = None\n    code: int\n\n\nclass ErrorResponse(BaseModel):\n    error: ErrorInfo\n\n\nclass ModelListModel(BaseModel):\n    id: str\n    object: str = \"model\"\n    created: int = Field(default_factory=lambda: int(time.time()))\n    owned_by: str = \"exo\"\n    # openwebui fields\n    hugging_face_id: str = Field(default=\"\")\n    name: str = Field(default=\"\")\n    description: str = Field(default=\"\")\n    context_length: int = Field(default=0)\n    tags: list[str] = Field(default=[])\n    storage_size_megabytes: int = Field(default=0)\n    supports_tensor: bool = Field(default=False)\n    tasks: list[str] = Field(default=[])\n    is_custom: bool = Field(default=False)\n    family: str = Field(default=\"\")\n    quantization: str = Field(default=\"\")\n    base_model: str = Field(default=\"\")\n    capabilities: list[str] = Field(default_factory=list)\n\n\nclass ModelList(BaseModel):\n    object: Literal[\"list\"] = \"list\"\n    data: list[ModelListModel]\n\n\nclass ChatCompletionMessageText(BaseModel):\n    type: Literal[\"text\"] = \"text\"\n    text: str\n\n\nclass ToolCallItem(BaseModel):\n    id: str = Field(default_factory=lambda: str(uuid4()))\n    name: str\n    arguments: str\n\n\nclass ToolCall(BaseModel):\n    id: str\n    index: int | None = None\n    type: Literal[\"function\"] = \"function\"\n    function: ToolCallItem\n\n\nclass ChatCompletionMessage(BaseModel):\n    role: Literal[\"system\", \"user\", \"assistant\", \"developer\", \"tool\", \"function\"]\n    content: (\n        str | ChatCompletionMessageText | list[ChatCompletionMessageText] | None\n    ) = None\n    reasoning_content: str | None = None\n    name: str | None = None\n    tool_calls: list[ToolCall] | None = None\n    tool_call_id: str | None = None\n    function_call: dict[str, Any] | None = None\n\n\nclass BenchChatCompletionMessage(ChatCompletionMessage):\n    pass\n\n\nclass TopLogprobItem(BaseModel):\n    token: str\n    logprob: float\n    bytes: list[int] | None = None\n\n\nclass LogprobsContentItem(BaseModel):\n    token: str\n    logprob: float\n    bytes: list[int] | None = None\n    top_logprobs: list[TopLogprobItem]\n\n\nclass Logprobs(BaseModel):\n    content: list[LogprobsContentItem] | None = None\n\n\nclass PromptTokensDetails(BaseModel):\n    cached_tokens: int = 0\n    audio_tokens: int = 0\n\n\nclass CompletionTokensDetails(BaseModel):\n    reasoning_tokens: int = 0\n    audio_tokens: int = 0\n    accepted_prediction_tokens: int = 0\n    rejected_prediction_tokens: int = 0\n\n\nclass Usage(BaseModel):\n    prompt_tokens: int\n    completion_tokens: int\n    total_tokens: int\n    prompt_tokens_details: PromptTokensDetails\n    completion_tokens_details: 
CompletionTokensDetails\n\n\nclass StreamingChoiceResponse(BaseModel):\n    index: int\n    delta: ChatCompletionMessage\n    logprobs: Logprobs | None = None\n    finish_reason: FinishReason | None = None\n    usage: Usage | None = None\n\n\nclass ChatCompletionChoice(BaseModel):\n    index: int\n    message: ChatCompletionMessage\n    logprobs: Logprobs | None = None\n    finish_reason: FinishReason | None = None\n\n\nclass ChatCompletionResponse(BaseModel):\n    id: str\n    object: Literal[\"chat.completion\"] = \"chat.completion\"\n    created: int\n    model: str\n    choices: list[ChatCompletionChoice | StreamingChoiceResponse]\n    usage: Usage | None = None\n    service_tier: str | None = None\n\n\nclass GenerationStats(BaseModel):\n    prompt_tps: float\n    generation_tps: float\n    prompt_tokens: int\n    generation_tokens: int\n    peak_memory_usage: Memory\n\n\nclass ImageGenerationStats(BaseModel):\n    seconds_per_step: float\n    total_generation_time: float\n\n    num_inference_steps: int\n    num_images: int\n\n    image_width: int\n    image_height: int\n\n    peak_memory_usage: Memory\n\n\nclass NodePowerStats(BaseModel, frozen=True):\n    node_id: NodeId\n    samples: int\n    avg_sys_power: float\n\n\nclass PowerUsage(BaseModel, frozen=True):\n    elapsed_seconds: float\n    nodes: list[NodePowerStats]\n    total_avg_sys_power_watts: float\n    total_energy_joules: float\n\n\nclass BenchChatCompletionResponse(ChatCompletionResponse):\n    generation_stats: GenerationStats | None = None\n    power_usage: PowerUsage | None = None\n\n\nclass StreamOptions(BaseModel):\n    include_usage: bool = False\n\n\nclass ChatCompletionRequest(BaseModel):\n    model: ModelId\n    frequency_penalty: float | None = None\n    messages: list[ChatCompletionMessage]\n    logit_bias: dict[str, int] | None = None\n    logprobs: bool | None = None\n    top_logprobs: int | None = None\n    max_tokens: int | None = None\n    n: int | None = None\n    presence_penalty: float | None = None\n    response_format: dict[str, Any] | None = None\n    seed: int | None = None\n    stop: str | list[str] | None = None\n    stream: bool = False\n    stream_options: StreamOptions | None = None\n    temperature: float | None = None\n    top_p: float | None = None\n    top_k: int | None = None\n    tools: list[dict[str, Any]] | None = None\n    reasoning_effort: ReasoningEffort | None = None\n    enable_thinking: bool | None = None\n    min_p: float | None = None\n    repetition_penalty: float | None = None\n    repetition_context_size: int | None = None\n    tool_choice: str | dict[str, Any] | None = None\n    parallel_tool_calls: bool | None = None\n    user: str | None = None\n\n\nclass BenchChatCompletionRequest(ChatCompletionRequest):\n    pass\n\n\nclass AddCustomModelParams(BaseModel):\n    model_id: ModelId\n\n\nclass HuggingFaceSearchResult(BaseModel):\n    id: str\n    author: str = \"\"\n    downloads: int = 0\n    likes: int = 0\n    last_modified: str = \"\"\n    tags: list[str] = Field(default_factory=list)\n\n\nclass PlaceInstanceParams(BaseModel):\n    model_id: ModelId\n    sharding: Sharding = Sharding.Pipeline\n    instance_meta: InstanceMeta = InstanceMeta.MlxRing\n    min_nodes: int = 1\n\n\nclass CreateInstanceParams(BaseModel):\n    instance: Instance\n\n\nclass PlacementPreview(BaseModel):\n    model_id: ModelId\n    sharding: Sharding\n    instance_meta: InstanceMeta\n    instance: Instance | None = None\n    # Keys are NodeId strings, values are additional bytes that would be used 
on that node\n    memory_delta_by_node: dict[str, int] | None = None\n    error: str | None = None\n\n\nclass PlacementPreviewResponse(BaseModel):\n    previews: list[PlacementPreview]\n\n\nclass DeleteInstanceTaskParams(BaseModel):\n    instance_id: str\n\n\nclass CreateInstanceResponse(BaseModel):\n    message: str\n    command_id: CommandId\n    model_card: ModelCard\n\n\nclass DeleteInstanceResponse(BaseModel):\n    message: str\n    command_id: CommandId\n    instance_id: InstanceId\n\n\nclass CancelCommandResponse(BaseModel):\n    message: str\n    command_id: CommandId\n\n\nImageSize = Literal[\n    \"auto\",\n    \"512x512\",\n    \"768x768\",\n    \"1024x768\",\n    \"768x1024\",\n    \"1024x1024\",\n    \"1024x1536\",\n    \"1536x1024\",\n]\n\n\ndef normalize_image_size(v: object) -> ImageSize:\n    \"\"\"Shared validator for ImageSize fields: maps None → \"auto\" and rejects invalid values.\"\"\"\n    if v is None:\n        return \"auto\"\n    if v not in get_args(ImageSize):\n        raise ValueError(f\"Invalid size: {v!r}. Must be one of {get_args(ImageSize)}\")\n    return v  # pyright: ignore[reportReturnType]\n\n\nclass AdvancedImageParams(BaseModel):\n    seed: Annotated[int, Field(ge=0)] | None = None\n    num_inference_steps: Annotated[int, Field(ge=1, le=100)] | None = None\n    guidance: Annotated[float, Field(ge=1.0, le=20.0)] | None = None\n    negative_prompt: str | None = None\n    num_sync_steps: Annotated[int, Field(ge=1, le=100)] | None = None\n\n\nclass ImageGenerationTaskParams(BaseModel):\n    prompt: str\n    background: str | None = None\n    model: str\n    moderation: str | None = None\n    n: int | None = 1\n    output_compression: int | None = None\n    output_format: Literal[\"png\", \"jpeg\", \"webp\"] = \"png\"\n    partial_images: int | None = 0\n    quality: Literal[\"high\", \"medium\", \"low\"] | None = \"medium\"\n    response_format: Literal[\"url\", \"b64_json\"] | None = \"b64_json\"\n    size: ImageSize = \"auto\"\n    stream: bool | None = False\n    style: str | None = \"vivid\"\n    user: str | None = None\n    advanced_params: AdvancedImageParams | None = None\n    # Internal flag for benchmark mode - set by API, preserved through serialization\n    bench: bool = False\n\n    @field_validator(\"size\", mode=\"before\")\n    @classmethod\n    def normalize_size(cls, v: object) -> ImageSize:\n        return normalize_image_size(v)\n\n\nclass BenchImageGenerationTaskParams(ImageGenerationTaskParams):\n    bench: bool = True\n\n\nclass ImageEditsTaskParams(BaseModel):\n    \"\"\"Internal task params for image-editing requests.\"\"\"\n\n    image_data: str = \"\"  # Base64-encoded image (empty when using chunked transfer)\n    total_input_chunks: int = 0\n    prompt: str\n    model: str\n    n: int | None = 1\n    quality: Literal[\"high\", \"medium\", \"low\"] | None = \"medium\"\n    output_format: Literal[\"png\", \"jpeg\", \"webp\"] = \"png\"\n    response_format: Literal[\"url\", \"b64_json\"] | None = \"b64_json\"\n    size: ImageSize = \"auto\"\n    image_strength: float | None = 0.7\n    stream: bool = False\n    partial_images: int | None = 0\n    advanced_params: AdvancedImageParams | None = None\n    bench: bool = False\n\n    @field_validator(\"size\", mode=\"before\")\n    @classmethod\n    def normalize_size(cls, v: object) -> ImageSize:\n        return normalize_image_size(v)\n\n    def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:\n        for name, value in super().__repr_args__():  # pyright: 
ignore[reportAny]\n            if name == \"image_data\":\n                yield name, f\"<{len(self.image_data)} chars>\"\n            elif name is not None:\n                yield name, value\n\n\nclass ImageData(BaseModel):\n    b64_json: str | None = None\n    url: str | None = None\n    revised_prompt: str | None = None\n\n    def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:\n        for name, value in super().__repr_args__():  # pyright: ignore[reportAny]\n            if name == \"b64_json\" and self.b64_json is not None:\n                yield name, f\"<{len(self.b64_json)} chars>\"\n            elif name is not None:\n                yield name, value\n\n\nclass ImageGenerationResponse(BaseModel):\n    created: int = Field(default_factory=lambda: int(time.time()))\n    data: list[ImageData]\n\n\nclass BenchImageGenerationResponse(ImageGenerationResponse):\n    generation_stats: ImageGenerationStats | None = None\n    power_usage: PowerUsage | None = None\n\n\nclass ImageListItem(BaseModel, frozen=True):\n    image_id: str\n    url: str\n    content_type: str\n    expires_at: float\n\n\nclass ImageListResponse(BaseModel, frozen=True):\n    data: list[ImageListItem]\n\n\nclass StartDownloadParams(CamelCaseModel):\n    target_node_id: NodeId\n    shard_metadata: ShardMetadata\n\n\nclass StartDownloadResponse(CamelCaseModel):\n    command_id: CommandId\n\n\nclass DeleteDownloadResponse(CamelCaseModel):\n    command_id: CommandId\n\n\nclass TraceEventResponse(CamelCaseModel):\n    name: str\n    start_us: int\n    duration_us: int\n    rank: int\n    category: str\n\n\nclass TraceResponse(CamelCaseModel):\n    task_id: str\n    traces: list[TraceEventResponse]\n\n\nclass TraceCategoryStats(CamelCaseModel):\n    total_us: int\n    count: int\n    min_us: int\n    max_us: int\n    avg_us: float\n\n\nclass TraceRankStats(CamelCaseModel):\n    by_category: dict[str, TraceCategoryStats]\n\n\nclass TraceStatsResponse(CamelCaseModel):\n    task_id: str\n    total_wall_time_us: int\n    by_category: dict[str, TraceCategoryStats]\n    by_rank: dict[int, TraceRankStats]\n\n\nclass TraceListItem(CamelCaseModel):\n    task_id: str\n    created_at: str\n    file_size: int\n\n\nclass TraceListResponse(CamelCaseModel):\n    traces: list[TraceListItem]\n\n\nclass DeleteTracesRequest(CamelCaseModel):\n    task_ids: list[str]\n\n\nclass DeleteTracesResponse(CamelCaseModel):\n    deleted: list[str]\n    not_found: list[str]\n"
  },
  {
    "path": "src/exo/api/types/claude_api.py",
    "content": "\"\"\"Claude Messages API types for request/response conversion.\"\"\"\n\nfrom typing import Any, Literal\n\nfrom pydantic import BaseModel, Field\n\nfrom exo.shared.types.common import ModelId\n\n# Tool definition types\nClaudeToolInputSchema = dict[str, Any]\n\n\nclass ClaudeToolDefinition(BaseModel, frozen=True):\n    \"\"\"Tool definition in Claude Messages API request.\"\"\"\n\n    name: str\n    description: str | None = None\n    input_schema: ClaudeToolInputSchema\n\n\n# Type aliases\nClaudeRole = Literal[\"user\", \"assistant\"]\nClaudeStopReason = Literal[\"end_turn\", \"max_tokens\", \"stop_sequence\", \"tool_use\"]\n\n\n# Content block types\nclass ClaudeTextBlock(BaseModel, frozen=True):\n    \"\"\"Text content block in Claude Messages API.\"\"\"\n\n    type: Literal[\"text\"] = \"text\"\n    text: str\n\n\nclass ClaudeImageSource(BaseModel, frozen=True):\n    \"\"\"Image source for Claude image blocks.\"\"\"\n\n    type: Literal[\"base64\", \"url\"]\n    media_type: str | None = None\n    data: str | None = None\n    url: str | None = None\n\n\nclass ClaudeImageBlock(BaseModel, frozen=True):\n    \"\"\"Image content block in Claude Messages API.\"\"\"\n\n    type: Literal[\"image\"] = \"image\"\n    source: ClaudeImageSource\n\n\nclass ClaudeThinkingBlock(BaseModel, frozen=True):\n    \"\"\"Thinking content block in Claude Messages API.\"\"\"\n\n    type: Literal[\"thinking\"] = \"thinking\"\n    thinking: str\n    signature: str | None = None\n\n\nclass ClaudeToolUseBlock(BaseModel, frozen=True):\n    \"\"\"Tool use content block in Claude Messages API.\"\"\"\n\n    type: Literal[\"tool_use\"] = \"tool_use\"\n    id: str\n    name: str\n    input: dict[str, Any]\n\n\nclass ClaudeToolResultBlock(BaseModel, frozen=True):\n    \"\"\"Tool result content block in Claude Messages API request.\"\"\"\n\n    type: Literal[\"tool_result\"] = \"tool_result\"\n    tool_use_id: str\n    content: str | list[ClaudeTextBlock] | None = None\n    is_error: bool | None = None\n    cache_control: dict[str, str] | None = None\n\n\nClaudeContentBlock = (\n    ClaudeTextBlock | ClaudeImageBlock | ClaudeThinkingBlock | ClaudeToolUseBlock\n)\n\n# Input content blocks can also include tool_result (sent by user after tool_use)\nClaudeInputContentBlock = (\n    ClaudeTextBlock\n    | ClaudeImageBlock\n    | ClaudeThinkingBlock\n    | ClaudeToolUseBlock\n    | ClaudeToolResultBlock\n)\n\n\n# Request types\nclass ClaudeMessage(BaseModel, frozen=True):\n    \"\"\"Message in Claude Messages API request.\"\"\"\n\n    role: ClaudeRole\n    content: str | list[ClaudeInputContentBlock]\n\n\nclass ClaudeThinkingConfig(BaseModel, frozen=True):\n    type: Literal[\"enabled\", \"disabled\", \"adaptive\"]\n    budget_tokens: int | None = None\n\n\nclass ClaudeMessagesRequest(BaseModel):\n    \"\"\"Request body for Claude Messages API.\"\"\"\n\n    model: ModelId\n    max_tokens: int\n    messages: list[ClaudeMessage]\n    system: str | list[ClaudeTextBlock] | None = None\n    stop_sequences: list[str] | None = None\n    stream: bool = False\n    temperature: float | None = None\n    top_p: float | None = None\n    top_k: int | None = None\n    tools: list[ClaudeToolDefinition] | None = None\n    metadata: dict[str, str] | None = None\n    thinking: ClaudeThinkingConfig | None = None\n\n\n# Response types\nclass ClaudeUsage(BaseModel, frozen=True):\n    \"\"\"Token usage in Claude Messages API response.\"\"\"\n\n    input_tokens: int\n    output_tokens: int\n\n\nclass 
ClaudeMessagesResponse(BaseModel, frozen=True):\n    \"\"\"Response body for Claude Messages API.\"\"\"\n\n    id: str\n    type: Literal[\"message\"] = \"message\"\n    role: Literal[\"assistant\"] = \"assistant\"\n    content: list[ClaudeContentBlock]\n    model: str\n    stop_reason: ClaudeStopReason | None = None\n    stop_sequence: str | None = None\n    usage: ClaudeUsage\n\n\n# Streaming event types\nclass ClaudeMessageStart(BaseModel, frozen=True):\n    \"\"\"Partial message in message_start event.\"\"\"\n\n    id: str\n    type: Literal[\"message\"] = \"message\"\n    role: Literal[\"assistant\"] = \"assistant\"\n    content: list[ClaudeTextBlock] = Field(default_factory=list)\n    model: str\n    stop_reason: ClaudeStopReason | None = None\n    stop_sequence: str | None = None\n    usage: ClaudeUsage\n\n\nclass ClaudeMessageStartEvent(BaseModel, frozen=True):\n    \"\"\"Event sent at start of message stream.\"\"\"\n\n    type: Literal[\"message_start\"] = \"message_start\"\n    message: ClaudeMessageStart\n\n\nclass ClaudeContentBlockStartEvent(BaseModel, frozen=True):\n    \"\"\"Event sent at start of a content block.\"\"\"\n\n    type: Literal[\"content_block_start\"] = \"content_block_start\"\n    index: int\n    content_block: ClaudeTextBlock | ClaudeThinkingBlock | ClaudeToolUseBlock\n\n\nclass ClaudeTextDelta(BaseModel, frozen=True):\n    \"\"\"Delta for text content block.\"\"\"\n\n    type: Literal[\"text_delta\"] = \"text_delta\"\n    text: str\n\n\nclass ClaudeThinkingDelta(BaseModel, frozen=True):\n    \"\"\"Delta for thinking content block.\"\"\"\n\n    type: Literal[\"thinking_delta\"] = \"thinking_delta\"\n    thinking: str\n\n\nclass ClaudeInputJsonDelta(BaseModel, frozen=True):\n    \"\"\"Delta for tool use input JSON content block.\"\"\"\n\n    type: Literal[\"input_json_delta\"] = \"input_json_delta\"\n    partial_json: str\n\n\nclass ClaudeContentBlockDeltaEvent(BaseModel, frozen=True):\n    \"\"\"Event sent for content block delta.\"\"\"\n\n    type: Literal[\"content_block_delta\"] = \"content_block_delta\"\n    index: int\n    delta: ClaudeTextDelta | ClaudeThinkingDelta | ClaudeInputJsonDelta\n\n\nclass ClaudeContentBlockStopEvent(BaseModel, frozen=True):\n    \"\"\"Event sent at end of a content block.\"\"\"\n\n    type: Literal[\"content_block_stop\"] = \"content_block_stop\"\n    index: int\n\n\nclass ClaudeMessageDeltaUsage(BaseModel, frozen=True):\n    \"\"\"Usage in message_delta event.\"\"\"\n\n    output_tokens: int\n\n\nclass ClaudeMessageDelta(BaseModel, frozen=True):\n    \"\"\"Delta in message_delta event.\"\"\"\n\n    stop_reason: ClaudeStopReason | None = None\n    stop_sequence: str | None = None\n\n\nclass ClaudeMessageDeltaEvent(BaseModel, frozen=True):\n    \"\"\"Event sent with final message delta.\"\"\"\n\n    type: Literal[\"message_delta\"] = \"message_delta\"\n    delta: ClaudeMessageDelta\n    usage: ClaudeMessageDeltaUsage\n\n\nclass ClaudeMessageStopEvent(BaseModel, frozen=True):\n    \"\"\"Event sent at end of message stream.\"\"\"\n\n    type: Literal[\"message_stop\"] = \"message_stop\"\n\n\nClaudeStreamEvent = (\n    ClaudeMessageStartEvent\n    | ClaudeContentBlockStartEvent\n    | ClaudeContentBlockDeltaEvent\n    | ClaudeContentBlockStopEvent\n    | ClaudeMessageDeltaEvent\n    | ClaudeMessageStopEvent\n)\n"
  },
  {
    "path": "src/exo/api/types/ollama_api.py",
    "content": "from __future__ import annotations\n\nimport time\nfrom typing import Any, Literal\n\nfrom pydantic import BaseModel, Field\n\nfrom exo.shared.models.model_cards import ModelId\n\n# https://github.com/ollama/ollama/blob/main/docs/api.md\n\nOllamaRole = Literal[\"system\", \"user\", \"assistant\", \"tool\"]\nOllamaDoneReason = Literal[\"stop\", \"length\", \"tool_call\", \"error\"]\n\n\nclass OllamaToolFunction(BaseModel, frozen=True):\n    name: str\n    arguments: dict[str, Any] | str\n    index: int | None = None\n\n\nclass OllamaToolCall(BaseModel, frozen=True):\n    id: str | None = None\n    type: Literal[\"function\"] | None = None\n    function: OllamaToolFunction\n\n\nclass OllamaMessage(BaseModel, frozen=True):\n    role: OllamaRole\n    content: str | None = None\n    thinking: str | None = None\n    tool_calls: list[OllamaToolCall] | None = None\n    name: str | None = None\n    tool_name: str | None = None\n    images: list[str] | None = None\n\n\nclass OllamaOptions(BaseModel, frozen=True):\n    num_predict: int | None = None\n    temperature: float | None = None\n    top_p: float | None = None\n    top_k: int | None = None\n    stop: str | list[str] | None = None\n    seed: int | None = None\n\n\nclass OllamaChatRequest(BaseModel, frozen=True):\n    model: ModelId\n    messages: list[OllamaMessage]\n    stream: bool = True\n    options: OllamaOptions | None = None\n    tools: list[dict[str, Any]] | None = None\n    format: Literal[\"json\"] | dict[str, Any] | None = None\n    keep_alive: str | int | None = None\n    think: bool | None = None\n\n\nclass OllamaGenerateRequest(BaseModel, frozen=True):\n    model: ModelId\n    prompt: str = \"\"\n    system: str | None = None\n    stream: bool = True\n    options: OllamaOptions | None = None\n    format: Literal[\"json\"] | dict[str, Any] | None = None\n    keep_alive: str | int | None = None\n    think: bool | None = None\n    raw: bool = False\n\n\nclass OllamaGenerateResponse(BaseModel, frozen=True, strict=True):\n    model: str\n    created_at: str = Field(\n        default_factory=lambda: time.strftime(\"%Y-%m-%dT%H:%M:%SZ\", time.gmtime())\n    )\n    response: str\n    thinking: str | None = None\n    done: bool\n    done_reason: OllamaDoneReason | None = None\n    total_duration: int | None = None\n    load_duration: int | None = None\n    prompt_eval_count: int | None = None\n    prompt_eval_duration: int | None = None\n    eval_count: int | None = None\n    eval_duration: int | None = None\n\n\nclass OllamaShowRequest(BaseModel, frozen=True):\n    name: str | None = None\n    model: str | None = None\n    verbose: bool | None = None\n\n\nclass OllamaChatResponse(BaseModel, frozen=True, strict=True):\n    model: str\n    created_at: str = Field(\n        default_factory=lambda: time.strftime(\"%Y-%m-%dT%H:%M:%SZ\", time.gmtime())\n    )\n    message: OllamaMessage\n    done: bool\n    done_reason: OllamaDoneReason | None = None\n    total_duration: int | None = None\n    load_duration: int | None = None\n    prompt_eval_count: int | None = None\n    prompt_eval_duration: int | None = None\n    eval_count: int | None = None\n    eval_duration: int | None = None\n\n\nclass OllamaModelDetails(BaseModel, frozen=True, strict=True):\n    format: str | None = None\n    family: str | None = None\n    parameter_size: str | None = None\n    quantization_level: str | None = None\n\n\nclass OllamaModelTag(BaseModel, frozen=True, strict=True):\n    name: str\n    model: str | None = None\n    modified_at: str | None = 
None\n    size: int | None = None\n    digest: str | None = None\n    details: OllamaModelDetails | None = None\n\n\nclass OllamaTagsResponse(BaseModel, frozen=True, strict=True):\n    models: list[OllamaModelTag]\n\n\nclass OllamaShowResponse(BaseModel, frozen=True, strict=True):\n    modelfile: str | None = None\n    parameters: str | None = None\n    template: str | None = None\n    details: OllamaModelDetails | None = None\n    model_info: dict[str, Any] | None = None\n\n\nclass OllamaPsModel(BaseModel, frozen=True, strict=True):\n    name: str\n    model: str\n    size: int\n    digest: str | None = None\n    details: OllamaModelDetails | None = None\n    expires_at: str | None = None\n    size_vram: int | None = None\n\n\nclass OllamaPsResponse(BaseModel, frozen=True, strict=True):\n    models: list[OllamaPsModel]\n"
  },
  {
    "path": "src/exo/api/types/openai_responses.py",
    "content": "\"\"\"OpenAI Responses API wire types.\n\nThese types model the OpenAI Responses API request/response format.\nResponsesRequest is the API-level wire type; for the canonical internal\ntask params type used by the inference pipeline, see\n``exo.shared.types.text_generation.TextGenerationTaskParams``.\n\"\"\"\n\nimport time\nfrom typing import Any, Literal\n\nfrom pydantic import BaseModel, Field\n\nfrom exo.shared.types.common import ModelId\nfrom exo.shared.types.text_generation import ReasoningEffort\n\n# Type aliases\nResponseStatus = Literal[\"completed\", \"failed\", \"in_progress\", \"incomplete\"]\nResponseRole = Literal[\"user\", \"assistant\", \"system\", \"developer\"]\n\n\n# Request input content part types\nclass ResponseInputTextPart(BaseModel, frozen=True):\n    \"\"\"Text content part in a Responses API input message.\"\"\"\n\n    type: Literal[\"input_text\"] = \"input_text\"\n    text: str\n\n\nclass ResponseOutputTextPart(BaseModel, frozen=True):\n    \"\"\"Output text content part (used when replaying assistant messages in input).\"\"\"\n\n    type: Literal[\"output_text\"] = \"output_text\"\n    text: str\n\n\nResponseContentPart = ResponseInputTextPart | ResponseOutputTextPart\n\n\n# Request input item types\nclass ResponseInputMessage(BaseModel, frozen=True):\n    \"\"\"Input message for Responses API.\"\"\"\n\n    role: ResponseRole\n    content: str | list[ResponseContentPart]\n    type: Literal[\"message\"] = \"message\"\n\n\nclass FunctionCallInputItem(BaseModel, frozen=True):\n    \"\"\"Function call item replayed in input (from a previous assistant response).\"\"\"\n\n    type: Literal[\"function_call\"] = \"function_call\"\n    id: str | None = None\n    call_id: str\n    name: str\n    arguments: str\n    status: ResponseStatus | None = None\n\n\nclass FunctionCallOutputInputItem(BaseModel, frozen=True):\n    \"\"\"Function call output item in input (user providing tool results).\"\"\"\n\n    type: Literal[\"function_call_output\"] = \"function_call_output\"\n    call_id: str\n    output: str\n    id: str | None = None\n    status: ResponseStatus | None = None\n\n\nResponseInputItem = (\n    ResponseInputMessage | FunctionCallInputItem | FunctionCallOutputInputItem\n)\n\n\nclass Reasoning(BaseModel, frozen=True):\n    \"\"\"Reasoning configuration for OpenAI Responses API.\"\"\"\n\n    effort: ReasoningEffort | None = None\n    summary: Literal[\"auto\", \"concise\", \"detailed\"] | None = None\n\n\nclass ResponsesRequest(BaseModel, frozen=True):\n    \"\"\"Request body for OpenAI Responses API.\n\n    This is the API wire type for the Responses endpoint. The canonical\n    internal task params type is ``TextGenerationTaskParams``; see the\n    ``responses_request_to_text_generation`` adapter for conversion.\n    \"\"\"\n\n    # --- OpenAI Responses API standard fields ---\n    model: ModelId\n    input: str | list[ResponseInputItem]\n    instructions: str | None = None\n    max_output_tokens: int | None = None\n    temperature: float | None = None\n    top_p: float | None = None\n    stream: bool = False\n    tools: list[dict[str, Any]] | None = None\n    metadata: dict[str, str] | None = None\n    reasoning: Reasoning | None = None\n\n    # --- exo extensions (not in OpenAI Responses API spec) ---\n    enable_thinking: bool | None = Field(\n        default=None,\n        description=\"[exo extension] Boolean thinking toggle. 
Not part of the OpenAI Responses API.\",\n        json_schema_extra={\"x-exo-extension\": True},\n    )\n\n    top_k: int | None = Field(\n        default=None,\n        description=\"[exo extension] Top-k sampling parameter. Not part of the OpenAI Responses API.\",\n        json_schema_extra={\"x-exo-extension\": True},\n    )\n    stop: str | list[str] | None = Field(\n        default=None,\n        description=\"[exo extension] Stop sequence(s). Not part of the OpenAI Responses API.\",\n        json_schema_extra={\"x-exo-extension\": True},\n    )\n    seed: int | None = Field(\n        default=None,\n        description=\"[exo extension] Seed for deterministic sampling. Not part of the OpenAI Responses API.\",\n        json_schema_extra={\"x-exo-extension\": True},\n    )\n\n    # --- Internal fields (preserved during serialization, hidden from OpenAPI schema) ---\n    chat_template_messages: list[dict[str, Any]] | None = Field(\n        default=None,\n        description=\"Internal: pre-formatted messages for tokenizer chat template. Not part of the OpenAI Responses API.\",\n        json_schema_extra={\"x-exo-internal\": True},\n    )\n\n\n# Response types\nclass ResponseOutputText(BaseModel, frozen=True):\n    \"\"\"Text content in response output.\"\"\"\n\n    type: Literal[\"output_text\"] = \"output_text\"\n    text: str\n    annotations: list[dict[str, str]] = Field(default_factory=list)\n\n\nclass ResponseMessageItem(BaseModel, frozen=True):\n    \"\"\"Message item in response output array.\"\"\"\n\n    type: Literal[\"message\"] = \"message\"\n    id: str\n    role: Literal[\"assistant\"] = \"assistant\"\n    content: list[ResponseOutputText]\n    status: ResponseStatus = \"completed\"\n\n\nclass ResponseFunctionCallItem(BaseModel, frozen=True):\n    \"\"\"Function call item in response output array.\"\"\"\n\n    type: Literal[\"function_call\"] = \"function_call\"\n    id: str\n    call_id: str\n    name: str\n    arguments: str\n    status: ResponseStatus = \"completed\"\n\n\nclass ResponseReasoningSummaryText(BaseModel, frozen=True):\n    \"\"\"Summary text part in a reasoning output item.\"\"\"\n\n    type: Literal[\"summary_text\"] = \"summary_text\"\n    text: str\n\n\nclass ResponseReasoningItem(BaseModel, frozen=True):\n    \"\"\"Reasoning output item in response output array.\"\"\"\n\n    type: Literal[\"reasoning\"] = \"reasoning\"\n    id: str\n    summary: list[ResponseReasoningSummaryText] = Field(default_factory=list)\n    status: ResponseStatus = \"completed\"\n\n\nResponseItem = ResponseMessageItem | ResponseFunctionCallItem | ResponseReasoningItem\n\n\nclass ResponseUsage(BaseModel, frozen=True):\n    \"\"\"Token usage in Responses API response.\"\"\"\n\n    input_tokens: int\n    output_tokens: int\n    total_tokens: int\n\n\nclass ResponsesResponse(BaseModel, frozen=True):\n    \"\"\"Response body for OpenAI Responses API.\"\"\"\n\n    id: str\n    object: Literal[\"response\"] = \"response\"\n    created_at: int = Field(default_factory=lambda: int(time.time()))\n    status: ResponseStatus = \"completed\"\n    model: str\n    output: list[ResponseItem]\n    output_text: str\n    usage: ResponseUsage | None = None\n\n\n# Streaming event types\nclass ResponseCreatedEvent(BaseModel, frozen=True):\n    \"\"\"Event sent when response is created.\"\"\"\n\n    type: Literal[\"response.created\"] = \"response.created\"\n    sequence_number: int\n    response: ResponsesResponse\n\n\nclass ResponseInProgressEvent(BaseModel, frozen=True):\n    \"\"\"Event sent when 
response starts processing.\"\"\"\n\n    type: Literal[\"response.in_progress\"] = \"response.in_progress\"\n    sequence_number: int\n    response: ResponsesResponse\n\n\nclass ResponseOutputItemAddedEvent(BaseModel, frozen=True):\n    \"\"\"Event sent when an output item is added.\"\"\"\n\n    type: Literal[\"response.output_item.added\"] = \"response.output_item.added\"\n    sequence_number: int\n    output_index: int\n    item: ResponseItem\n\n\nclass ResponseContentPartAddedEvent(BaseModel, frozen=True):\n    \"\"\"Event sent when a content part is added.\"\"\"\n\n    type: Literal[\"response.content_part.added\"] = \"response.content_part.added\"\n    sequence_number: int\n    item_id: str\n    output_index: int\n    content_index: int\n    part: ResponseOutputText\n\n\nclass ResponseTextDeltaEvent(BaseModel, frozen=True):\n    \"\"\"Event sent for text delta during streaming.\"\"\"\n\n    type: Literal[\"response.output_text.delta\"] = \"response.output_text.delta\"\n    sequence_number: int\n    item_id: str\n    output_index: int\n    content_index: int\n    delta: str\n\n\nclass ResponseTextDoneEvent(BaseModel, frozen=True):\n    \"\"\"Event sent when text content is done.\"\"\"\n\n    type: Literal[\"response.output_text.done\"] = \"response.output_text.done\"\n    sequence_number: int\n    item_id: str\n    output_index: int\n    content_index: int\n    text: str\n\n\nclass ResponseContentPartDoneEvent(BaseModel, frozen=True):\n    \"\"\"Event sent when a content part is done.\"\"\"\n\n    type: Literal[\"response.content_part.done\"] = \"response.content_part.done\"\n    sequence_number: int\n    item_id: str\n    output_index: int\n    content_index: int\n    part: ResponseOutputText\n\n\nclass ResponseOutputItemDoneEvent(BaseModel, frozen=True):\n    \"\"\"Event sent when an output item is done.\"\"\"\n\n    type: Literal[\"response.output_item.done\"] = \"response.output_item.done\"\n    sequence_number: int\n    output_index: int\n    item: ResponseItem\n\n\nclass ResponseFunctionCallArgumentsDeltaEvent(BaseModel, frozen=True):\n    \"\"\"Event sent for function call arguments delta during streaming.\"\"\"\n\n    type: Literal[\"response.function_call_arguments.delta\"] = (\n        \"response.function_call_arguments.delta\"\n    )\n    sequence_number: int\n    item_id: str\n    output_index: int\n    delta: str\n\n\nclass ResponseFunctionCallArgumentsDoneEvent(BaseModel, frozen=True):\n    \"\"\"Event sent when function call arguments are complete.\"\"\"\n\n    type: Literal[\"response.function_call_arguments.done\"] = (\n        \"response.function_call_arguments.done\"\n    )\n    sequence_number: int\n    item_id: str\n    output_index: int\n    name: str\n    arguments: str\n\n\nclass ResponseReasoningSummaryPartAddedEvent(BaseModel, frozen=True):\n    \"\"\"Event sent when a reasoning summary part is added.\"\"\"\n\n    type: Literal[\"response.reasoning_summary_part.added\"] = (\n        \"response.reasoning_summary_part.added\"\n    )\n    sequence_number: int\n    item_id: str\n    output_index: int\n    summary_index: int\n    part: ResponseReasoningSummaryText\n\n\nclass ResponseReasoningSummaryTextDeltaEvent(BaseModel, frozen=True):\n    \"\"\"Event sent for reasoning summary text delta during streaming.\"\"\"\n\n    type: Literal[\"response.reasoning_summary_text.delta\"] = (\n        \"response.reasoning_summary_text.delta\"\n    )\n    sequence_number: int\n    item_id: str\n    output_index: int\n    summary_index: int\n    delta: str\n\n\nclass 
ResponseReasoningSummaryTextDoneEvent(BaseModel, frozen=True):\n    \"\"\"Event sent when reasoning summary text is done.\"\"\"\n\n    type: Literal[\"response.reasoning_summary_text.done\"] = (\n        \"response.reasoning_summary_text.done\"\n    )\n    sequence_number: int\n    item_id: str\n    output_index: int\n    summary_index: int\n    text: str\n\n\nclass ResponseReasoningSummaryPartDoneEvent(BaseModel, frozen=True):\n    \"\"\"Event sent when a reasoning summary part is done.\"\"\"\n\n    type: Literal[\"response.reasoning_summary_part.done\"] = (\n        \"response.reasoning_summary_part.done\"\n    )\n    sequence_number: int\n    item_id: str\n    output_index: int\n    summary_index: int\n    part: ResponseReasoningSummaryText\n\n\nclass ResponseCompletedEvent(BaseModel, frozen=True):\n    \"\"\"Event sent when response is completed.\"\"\"\n\n    type: Literal[\"response.completed\"] = \"response.completed\"\n    sequence_number: int\n    response: ResponsesResponse\n\n\nResponsesStreamEvent = (\n    ResponseCreatedEvent\n    | ResponseInProgressEvent\n    | ResponseOutputItemAddedEvent\n    | ResponseContentPartAddedEvent\n    | ResponseTextDeltaEvent\n    | ResponseTextDoneEvent\n    | ResponseContentPartDoneEvent\n    | ResponseOutputItemDoneEvent\n    | ResponseFunctionCallArgumentsDeltaEvent\n    | ResponseFunctionCallArgumentsDoneEvent\n    | ResponseReasoningSummaryPartAddedEvent\n    | ResponseReasoningSummaryTextDeltaEvent\n    | ResponseReasoningSummaryTextDoneEvent\n    | ResponseReasoningSummaryPartDoneEvent\n    | ResponseCompletedEvent\n)\n"
  },
  {
    "path": "src/exo/download/coordinator.py",
    "content": "from dataclasses import dataclass, field\n\nimport anyio\nfrom anyio import current_time\nfrom loguru import logger\n\nfrom exo.download.download_utils import (\n    RepoDownloadProgress,\n    delete_model,\n    map_repo_download_progress_to_download_progress_data,\n    resolve_model_in_path,\n)\nfrom exo.download.shard_downloader import ShardDownloader\nfrom exo.shared.constants import EXO_MODELS_DIR, EXO_MODELS_PATH\nfrom exo.shared.models.model_cards import ModelId, get_model_cards\nfrom exo.shared.types.commands import (\n    CancelDownload,\n    DeleteDownload,\n    ForwarderDownloadCommand,\n    StartDownload,\n)\nfrom exo.shared.types.common import NodeId\nfrom exo.shared.types.events import (\n    Event,\n    NodeDownloadProgress,\n)\nfrom exo.shared.types.worker.downloads import (\n    DownloadCompleted,\n    DownloadFailed,\n    DownloadOngoing,\n    DownloadPending,\n    DownloadProgress,\n)\nfrom exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata\nfrom exo.utils.channels import Receiver, Sender\nfrom exo.utils.task_group import TaskGroup\n\n\n@dataclass\nclass DownloadCoordinator:\n    node_id: NodeId\n    shard_downloader: ShardDownloader\n    download_command_receiver: Receiver[ForwarderDownloadCommand]\n    event_sender: Sender[Event]\n    offline: bool = False\n\n    # Local state\n    download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)\n    active_downloads: dict[ModelId, anyio.CancelScope] = field(default_factory=dict)\n\n    _tg: TaskGroup = field(init=False, default_factory=TaskGroup)\n\n    # Per-model throttle for download progress events\n    _last_progress_time: dict[ModelId, float] = field(default_factory=dict)\n\n    def __post_init__(self) -> None:\n        self.shard_downloader.on_progress(self._download_progress_callback)\n\n    def _model_dir(self, model_id: ModelId) -> str:\n        return str(EXO_MODELS_DIR / model_id.normalize())\n\n    async def _download_progress_callback(\n        self, callback_shard: ShardMetadata, progress: RepoDownloadProgress\n    ) -> None:\n        model_id = callback_shard.model_card.model_id\n        throttle_interval_secs = 1.0\n\n        if progress.status == \"complete\":\n            completed = DownloadCompleted(\n                shard_metadata=callback_shard,\n                node_id=self.node_id,\n                total=progress.total,\n                model_directory=self._model_dir(model_id),\n            )\n            self.download_status[model_id] = completed\n            await self.event_sender.send(\n                NodeDownloadProgress(download_progress=completed)\n            )\n            self._last_progress_time.pop(model_id, None)\n        elif (\n            progress.status == \"in_progress\"\n            and current_time() - self._last_progress_time.get(model_id, 0.0)\n            > throttle_interval_secs\n        ):\n            ongoing = DownloadOngoing(\n                node_id=self.node_id,\n                shard_metadata=callback_shard,\n                download_progress=map_repo_download_progress_to_download_progress_data(\n                    progress\n                ),\n                model_directory=self._model_dir(model_id),\n            )\n            self.download_status[model_id] = ongoing\n            await self.event_sender.send(\n                NodeDownloadProgress(download_progress=ongoing)\n            )\n            self._last_progress_time[model_id] = current_time()\n\n    async def run(self) -> None:\n        
logger.info(\n            f\"Starting DownloadCoordinator{' (offline mode)' if self.offline else ''}\"\n        )\n        async with self._tg as tg:\n            tg.start_soon(self._command_processor)\n            tg.start_soon(self._emit_existing_download_progress)\n\n    def shutdown(self) -> None:\n        self._tg.cancel_tasks()\n\n    async def _command_processor(self) -> None:\n        with self.download_command_receiver as commands:\n            async for cmd in commands:\n                # Only process commands targeting this node\n                if cmd.command.target_node_id != self.node_id:\n                    continue\n\n                match cmd.command:\n                    case StartDownload(shard_metadata=shard):\n                        await self._start_download(shard)\n                    case DeleteDownload(model_id=model_id):\n                        await self._delete_download(model_id)\n                    case CancelDownload(model_id=model_id):\n                        await self._cancel_download(model_id)\n\n    async def _cancel_download(self, model_id: ModelId) -> None:\n        if model_id in self.active_downloads and model_id in self.download_status:\n            logger.info(f\"Cancelling download for {model_id}\")\n            self.active_downloads[model_id].cancel()\n            current_status = self.download_status[model_id]\n            pending = DownloadPending(\n                shard_metadata=current_status.shard_metadata,\n                node_id=self.node_id,\n                model_directory=self._model_dir(model_id),\n            )\n            self.download_status[model_id] = pending\n            await self.event_sender.send(\n                NodeDownloadProgress(download_progress=pending)\n            )\n\n    async def _start_download(self, shard: ShardMetadata) -> None:\n        model_id = shard.model_card.model_id\n\n        # Check if already downloading, complete, or recently failed\n        if model_id in self.download_status:\n            status = self.download_status[model_id]\n            if isinstance(status, (DownloadOngoing, DownloadCompleted, DownloadFailed)):\n                logger.debug(\n                    f\"Download for {model_id} already in progress, complete, or failed, skipping\"\n                )\n                return\n\n        # Check EXO_MODELS_PATH for pre-downloaded models\n        found_path = resolve_model_in_path(model_id)\n        if found_path is not None:\n            logger.info(\n                f\"DownloadCoordinator: Model {model_id} found in EXO_MODELS_PATH at {found_path}\"\n            )\n            completed = DownloadCompleted(\n                shard_metadata=shard,\n                node_id=self.node_id,\n                total=shard.model_card.storage_size,\n                model_directory=str(found_path),\n                read_only=True,\n            )\n            self.download_status[model_id] = completed\n            await self.event_sender.send(\n                NodeDownloadProgress(download_progress=completed)\n            )\n            return\n\n        # Emit pending status\n        progress = DownloadPending(\n            shard_metadata=shard,\n            node_id=self.node_id,\n            model_directory=self._model_dir(model_id),\n        )\n        self.download_status[model_id] = progress\n        await self.event_sender.send(NodeDownloadProgress(download_progress=progress))\n\n        # Check initial status from downloader\n        initial_progress = (\n            await 
self.shard_downloader.get_shard_download_status_for_shard(shard)\n        )\n\n        if initial_progress.status == \"complete\":\n            completed = DownloadCompleted(\n                shard_metadata=shard,\n                node_id=self.node_id,\n                total=initial_progress.total,\n                model_directory=self._model_dir(model_id),\n            )\n            self.download_status[model_id] = completed\n            await self.event_sender.send(\n                NodeDownloadProgress(download_progress=completed)\n            )\n            return\n\n        if self.offline:\n            logger.warning(\n                f\"Offline mode: model {model_id} is not fully available locally, cannot download\"\n            )\n            failed = DownloadFailed(\n                shard_metadata=shard,\n                node_id=self.node_id,\n                error_message=f\"Model files not found locally in offline mode: {model_id}\",\n                model_directory=self._model_dir(model_id),\n            )\n            self.download_status[model_id] = failed\n            await self.event_sender.send(NodeDownloadProgress(download_progress=failed))\n            return\n\n        # Start actual download\n        self._start_download_task(shard, initial_progress)\n\n    def _start_download_task(\n        self, shard: ShardMetadata, initial_progress: RepoDownloadProgress\n    ) -> None:\n        model_id = shard.model_card.model_id\n\n        # Emit ongoing status\n        status = DownloadOngoing(\n            node_id=self.node_id,\n            shard_metadata=shard,\n            download_progress=map_repo_download_progress_to_download_progress_data(\n                initial_progress\n            ),\n            model_directory=self._model_dir(model_id),\n        )\n        self.download_status[model_id] = status\n        self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))\n\n        async def download_wrapper(cancel_scope: anyio.CancelScope) -> None:\n            try:\n                with cancel_scope:\n                    await self.shard_downloader.ensure_shard(shard)\n            except Exception as e:\n                logger.error(f\"Download failed for {model_id}: {e}\")\n                failed = DownloadFailed(\n                    shard_metadata=shard,\n                    node_id=self.node_id,\n                    error_message=str(e),\n                    model_directory=self._model_dir(model_id),\n                )\n                self.download_status[model_id] = failed\n                await self.event_sender.send(\n                    NodeDownloadProgress(download_progress=failed)\n                )\n            except anyio.get_cancelled_exc_class():\n                # ignore cancellation - let cleanup do its thing\n                pass\n            finally:\n                self.active_downloads.pop(model_id, None)\n\n        scope = anyio.CancelScope()\n        self._tg.start_soon(download_wrapper, scope)\n        self.active_downloads[model_id] = scope\n\n    async def _delete_download(self, model_id: ModelId) -> None:\n        # Protect read-only models (from EXO_MODELS_PATH) from deletion\n        if model_id in self.download_status:\n            current = self.download_status[model_id]\n            if isinstance(current, DownloadCompleted) and current.read_only:\n                logger.warning(\n                    f\"Refusing to delete read-only model {model_id} (from EXO_MODELS_PATH)\"\n                )\n                return\n\n  
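      # An active download must be cancelled before its files are removed.\n  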
      # Cancel if active\n        if model_id in self.active_downloads:\n            logger.info(f\"Cancelling active download for {model_id} before deletion\")\n            self.active_downloads[model_id].cancel()\n\n        # Delete from disk\n        logger.info(f\"Deleting model files for {model_id}\")\n        deleted = await delete_model(model_id)\n\n        if deleted:\n            logger.info(f\"Successfully deleted model {model_id}\")\n        else:\n            logger.warning(f\"Model {model_id} was not found on disk\")\n\n        # Emit pending status to reset UI state, then remove from local tracking\n        if model_id in self.download_status:\n            current_status = self.download_status[model_id]\n            pending = DownloadPending(\n                shard_metadata=current_status.shard_metadata,\n                node_id=self.node_id,\n                model_directory=self._model_dir(model_id),\n            )\n            await self.event_sender.send(\n                NodeDownloadProgress(download_progress=pending)\n            )\n            del self.download_status[model_id]\n\n    async def _emit_existing_download_progress(self) -> None:\n        while True:\n            try:\n                logger.debug(\n                    \"DownloadCoordinator: Fetching and emitting existing download progress...\"\n                )\n                async for (\n                    _,\n                    progress,\n                ) in self.shard_downloader.get_shard_download_status():\n                    model_id = progress.shard.model_card.model_id\n\n                    # Active downloads emit progress via the callback — don't overwrite\n                    if model_id in self.active_downloads:\n                        continue\n\n                    if progress.status == \"complete\":\n                        status: DownloadProgress = DownloadCompleted(\n                            node_id=self.node_id,\n                            shard_metadata=progress.shard,\n                            total=progress.total,\n                            model_directory=self._model_dir(\n                                progress.shard.model_card.model_id\n                            ),\n                        )\n                    elif progress.status in [\"in_progress\", \"not_started\"]:\n                        if progress.downloaded_this_session.in_bytes == 0:\n                            status = DownloadPending(\n                                node_id=self.node_id,\n                                shard_metadata=progress.shard,\n                                model_directory=self._model_dir(\n                                    progress.shard.model_card.model_id\n                                ),\n                                downloaded=progress.downloaded,\n                                total=progress.total,\n                            )\n                        else:\n                            status = DownloadOngoing(\n                                node_id=self.node_id,\n                                shard_metadata=progress.shard,\n                                download_progress=map_repo_download_progress_to_download_progress_data(\n                                    progress\n                                ),\n                                model_directory=self._model_dir(\n                                    progress.shard.model_card.model_id\n                                ),\n                            )\n                    else:\n                  
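      # Statuses other than complete, in_progress, or not_started are not re-emitted\n                  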
      continue\n\n                    self.download_status[progress.shard.model_card.model_id] = status\n                    await self.event_sender.send(\n                        NodeDownloadProgress(download_progress=status)\n                    )\n                # Scan EXO_MODELS_PATH for pre-downloaded models\n                if EXO_MODELS_PATH is not None:\n                    for card in await get_model_cards():\n                        mid = card.model_id\n                        if mid in self.active_downloads:\n                            continue\n                        if isinstance(\n                            self.download_status.get(mid),\n                            (DownloadCompleted, DownloadOngoing, DownloadFailed),\n                        ):\n                            continue\n                        found = resolve_model_in_path(mid)\n                        if found is not None:\n                            path_shard = PipelineShardMetadata(\n                                model_card=card,\n                                device_rank=0,\n                                world_size=1,\n                                start_layer=0,\n                                end_layer=card.n_layers,\n                                n_layers=card.n_layers,\n                            )\n                            path_completed: DownloadProgress = DownloadCompleted(\n                                node_id=self.node_id,\n                                shard_metadata=path_shard,\n                                total=card.storage_size,\n                                model_directory=str(found),\n                                read_only=True,\n                            )\n                            self.download_status[mid] = path_completed\n                            await self.event_sender.send(\n                                NodeDownloadProgress(download_progress=path_completed)\n                            )\n\n                logger.debug(\n                    \"DownloadCoordinator: Done emitting existing download progress.\"\n                )\n            except Exception as e:\n                logger.error(\n                    f\"DownloadCoordinator: Error emitting existing download progress: {e}\"\n                )\n            await anyio.sleep(60)\n"
  },
  {
    "path": "src/exo/download/download_utils.py",
    "content": "import asyncio\nimport hashlib\nimport os\nimport shutil\nimport ssl\nimport time\nimport traceback\nfrom collections.abc import Awaitable\nfrom datetime import timedelta\nfrom pathlib import Path\nfrom typing import Callable, Literal\nfrom urllib.parse import urljoin\n\nimport aiofiles\nimport aiofiles.os as aios\nimport aiohttp\nimport certifi\nfrom huggingface_hub import (\n    snapshot_download,  # pyright: ignore[reportUnknownVariableType]\n)\nfrom loguru import logger\nfrom pydantic import (\n    TypeAdapter,\n)\n\nfrom exo.download.huggingface_utils import (\n    filter_repo_objects,\n    get_allow_patterns,\n    get_auth_headers,\n    get_hf_endpoint,\n    get_hf_token,\n)\nfrom exo.shared.constants import EXO_MODELS_DIR, EXO_MODELS_PATH\nfrom exo.shared.models.model_cards import ModelTask\nfrom exo.shared.types.common import ModelId\nfrom exo.shared.types.memory import Memory\nfrom exo.shared.types.worker.downloads import (\n    DownloadProgressData,\n    FileListEntry,\n    ModelSafetensorsIndex,\n    RepoDownloadProgress,\n    RepoFileDownloadProgress,\n)\nfrom exo.shared.types.worker.shards import ShardMetadata\n\n\nclass HuggingFaceAuthenticationError(Exception):\n    \"\"\"Raised when HuggingFace returns 401/403 for a model download.\"\"\"\n\n\nclass HuggingFaceRateLimitError(Exception):\n    \"\"\"429 Huggingface code\"\"\"\n\n\nasync def _build_auth_error_message(status_code: int, model_id: ModelId) -> str:\n    token = await get_hf_token()\n    if status_code == 401 and token is None:\n        return (\n            f\"Model '{model_id}' requires authentication. \"\n            f\"Set HF_TOKEN in the app's Advanced settings, set the HF_TOKEN environment variable, or run `hf auth login`. \"\n            f\"Get a token at https://huggingface.co/settings/tokens\"\n        )\n    elif status_code == 403:\n        return (\n            f\"Access denied to '{model_id}'. 
\"\n            f\"Please accept the model terms at https://huggingface.co/{model_id}\"\n        )\n    else:\n        return f\"Authentication failed for '{model_id}' (HTTP {status_code})\"\n\n\ndef trim_etag(etag: str) -> str:\n    if (etag[0] == '\"' and etag[-1] == '\"') or (etag[0] == \"'\" and etag[-1] == \"'\"):\n        return etag[1:-1]\n    return etag\n\n\ndef map_repo_file_download_progress_to_download_progress_data(\n    repo_file_download_progress: RepoFileDownloadProgress,\n) -> DownloadProgressData:\n    return DownloadProgressData(\n        downloaded=repo_file_download_progress.downloaded,\n        downloaded_this_session=repo_file_download_progress.downloaded_this_session,\n        total=repo_file_download_progress.total,\n        completed_files=1 if repo_file_download_progress.status == \"complete\" else 0,\n        total_files=1,\n        speed=repo_file_download_progress.speed,\n        eta_ms=int(repo_file_download_progress.eta.total_seconds() * 1000),\n        files={},\n    )\n\n\ndef map_repo_download_progress_to_download_progress_data(\n    repo_download_progress: RepoDownloadProgress,\n) -> DownloadProgressData:\n    return DownloadProgressData(\n        total=repo_download_progress.total,\n        downloaded=repo_download_progress.downloaded,\n        downloaded_this_session=repo_download_progress.downloaded_this_session,\n        completed_files=repo_download_progress.completed_files,\n        total_files=repo_download_progress.total_files,\n        speed=repo_download_progress.overall_speed,\n        eta_ms=int(repo_download_progress.overall_eta.total_seconds() * 1000),\n        files={\n            file_path: map_repo_file_download_progress_to_download_progress_data(\n                file_progress\n            )\n            for file_path, file_progress in repo_download_progress.file_progress.items()\n        },\n    )\n\n\ndef resolve_model_in_path(model_id: ModelId) -> Path | None:\n    \"\"\"Search EXO_MODELS_PATH directories for a pre-existing model.\n\n    Checks each directory for the normalized name (org--model).  
A candidate\n    is only returned if ``is_model_directory_complete`` confirms all weight\n    files are present.\n    \"\"\"\n    if EXO_MODELS_PATH is None:\n        return None\n    normalized = model_id.normalize()\n    for search_dir in EXO_MODELS_PATH:\n        candidate = search_dir / normalized\n        if candidate.is_dir() and is_model_directory_complete(candidate):\n            return candidate\n    return None\n\n\ndef build_model_path(model_id: ModelId) -> Path:\n    found = resolve_model_in_path(model_id)\n    if found is not None:\n        return found\n    return EXO_MODELS_DIR / model_id.normalize()\n\n\nasync def resolve_model_path_for_repo(model_id: ModelId) -> Path:\n    return (await ensure_models_dir()) / model_id.normalize()\n\n\nasync def ensure_models_dir() -> Path:\n    await aios.makedirs(EXO_MODELS_DIR, exist_ok=True)\n    return EXO_MODELS_DIR\n\n\nasync def delete_model(model_id: ModelId) -> bool:\n    models_dir = await ensure_models_dir()\n    model_dir = models_dir / model_id.normalize()\n    cache_dir = models_dir / \"caches\" / model_id.normalize()\n\n    deleted = False\n    if await aios.path.exists(model_dir):\n        await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)\n        deleted = True\n\n    # Also clear cache\n    if await aios.path.exists(cache_dir):\n        await asyncio.to_thread(shutil.rmtree, cache_dir, ignore_errors=False)\n\n    return deleted\n\n\nasync def seed_models(seed_dir: str | Path):\n    \"\"\"Move models from resources folder to EXO_MODELS_DIR.\"\"\"\n    source_dir = Path(seed_dir)\n    dest_dir = await ensure_models_dir()\n    for path in source_dir.iterdir():\n        if path.is_dir() and path.name.startswith(\"models--\"):\n            dest_path = dest_dir / path.name\n            if await aios.path.exists(dest_path):\n                logger.info(f\"Model {path.name} already exists at {dest_path}; skipping\")\n            else:\n                try:\n                    await aios.rename(str(path), str(dest_path))\n                except Exception:\n                    logger.error(f\"Error seeding model {path} to {dest_path}\")\n                    logger.error(traceback.format_exc())\n\n\ndef _scan_model_directory(\n    model_dir: Path, recursive: bool = False\n) -> list[FileListEntry] | None:\n    \"\"\"Scan a local model directory and build a file list.\n\n    Requires at least one ``*.safetensors.index.json``.  
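Returns ``None``
    when no index file is found. 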
Every weight file\n    referenced by the index that is missing on disk gets ``size=None``.\n    \"\"\"\n    index_files = list(model_dir.glob(\"**/*.safetensors.index.json\"))\n    if not index_files:\n        return None\n\n    entries_by_path: dict[str, FileListEntry] = {}\n\n    if recursive:\n        for dirpath, _, filenames in os.walk(model_dir):\n            for filename in filenames:\n                if filename.endswith(\".partial\"):\n                    continue\n                full_path = Path(dirpath) / filename\n                rel_path = str(full_path.relative_to(model_dir))\n                entries_by_path[rel_path] = FileListEntry(\n                    type=\"file\",\n                    path=rel_path,\n                    size=full_path.stat().st_size,\n                )\n    else:\n        for item in model_dir.iterdir():\n            if item.is_file() and not item.name.endswith(\".partial\"):\n                entries_by_path[item.name] = FileListEntry(\n                    type=\"file\",\n                    path=item.name,\n                    size=item.stat().st_size,\n                )\n\n    # Add expected weight files from index that haven't been downloaded yet\n    for index_file in index_files:\n        try:\n            index_data = ModelSafetensorsIndex.model_validate_json(\n                index_file.read_text()\n            )\n            relative_dir = index_file.parent.relative_to(model_dir)\n            for filename in set(index_data.weight_map.values()):\n                rel_path = (\n                    str(relative_dir / filename)\n                    if relative_dir != Path(\".\")\n                    else filename\n                )\n                if rel_path not in entries_by_path:\n                    entries_by_path[rel_path] = FileListEntry(\n                        type=\"file\",\n                        path=rel_path,\n                        size=None,\n                    )\n        except Exception:\n            continue\n\n    return list(entries_by_path.values())\n\n\ndef is_model_directory_complete(model_dir: Path) -> bool:\n    \"\"\"Check if a model directory contains all required weight files.\"\"\"\n    file_list = _scan_model_directory(model_dir, recursive=True)\n    return file_list is not None and all(f.size is not None for f in file_list)\n\n\nasync def _build_file_list_from_local_directory(\n    model_id: ModelId,\n    recursive: bool = False,\n) -> list[FileListEntry] | None:\n    \"\"\"Build a file list from locally existing model files.\n\n    We can only figure out the files we need from safetensors index, so\n    a local directory must contain a *.safetensors.index.json and\n    safetensors listed there.\n    \"\"\"\n    model_dir = (await ensure_models_dir()) / model_id.normalize()\n    if not await aios.path.exists(model_dir):\n        return None\n\n    file_list = await asyncio.to_thread(_scan_model_directory, model_dir, recursive)\n    if not file_list:\n        return None\n    return file_list\n\n\n_fetched_file_lists_this_session: set[str] = set()\n\n\nasync def fetch_file_list_with_cache(\n    model_id: ModelId,\n    revision: str = \"main\",\n    recursive: bool = False,\n    skip_internet: bool = False,\n    on_connection_lost: Callable[[], None] = lambda: None,\n) -> list[FileListEntry]:\n    target_dir = (await ensure_models_dir()) / \"caches\" / model_id.normalize()\n    await aios.makedirs(target_dir, exist_ok=True)\n    cache_file = target_dir / f\"{model_id.normalize()}--{revision}--file_list.json\"\n    
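# File lists fetched earlier in this process are recorded in
    # _fetched_file_lists_this_session, so the on-disk cache can be reused
    # without re-hitting the HF API.
    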
cache_key = f\"{model_id.normalize()}--{revision}\"\n\n    if cache_key in _fetched_file_lists_this_session and await aios.path.exists(\n        cache_file\n    ):\n        async with aiofiles.open(cache_file, \"r\") as f:\n            return TypeAdapter(list[FileListEntry]).validate_json(await f.read())\n\n    if skip_internet:\n        if await aios.path.exists(cache_file):\n            async with aiofiles.open(cache_file, \"r\") as f:\n                return TypeAdapter(list[FileListEntry]).validate_json(await f.read())\n        local_file_list = await _build_file_list_from_local_directory(\n            model_id, recursive\n        )\n        if local_file_list is not None:\n            logger.warning(\n                f\"No internet and no cached file list for {model_id} - using local file list\"\n            )\n            return local_file_list\n        raise FileNotFoundError(\n            f\"No internet connection and no cached file list for {model_id}\"\n        )\n\n    try:\n        file_list = await fetch_file_list_with_retry(\n            model_id,\n            revision,\n            recursive=recursive,\n            on_connection_lost=on_connection_lost,\n        )\n        async with aiofiles.open(cache_file, \"w\") as f:\n            await f.write(\n                TypeAdapter(list[FileListEntry]).dump_json(file_list).decode()\n            )\n        _fetched_file_lists_this_session.add(cache_key)\n        return file_list\n    except Exception as e:\n        logger.opt(exception=e).warning(\n            \"Ran into exception when fetching file list from HF.\"\n        )\n\n        if await aios.path.exists(cache_file):\n            logger.warning(\n                f\"No cached file list for {model_id} - using local file list\"\n            )\n            async with aiofiles.open(cache_file, \"r\") as f:\n                return TypeAdapter(list[FileListEntry]).validate_json(await f.read())\n        local_file_list = await _build_file_list_from_local_directory(\n            model_id, recursive\n        )\n        if local_file_list is not None:\n            logger.warning(\n                f\"Failed to fetch file list for {model_id} and no cache exists, \"\n            )\n            return local_file_list\n        raise FileNotFoundError(f\"Failed to fetch file list for {model_id}: {e}\") from e\n\n\nasync def fetch_file_list_with_retry(\n    model_id: ModelId,\n    revision: str = \"main\",\n    path: str = \"\",\n    recursive: bool = False,\n    on_connection_lost: Callable[[], None] = lambda: None,\n) -> list[FileListEntry]:\n    n_attempts = 3\n    for attempt in range(n_attempts):\n        try:\n            return await _fetch_file_list(model_id, revision, path, recursive)\n        except HuggingFaceAuthenticationError:\n            raise\n        except Exception as e:\n            on_connection_lost()\n            if attempt == n_attempts - 1:\n                raise e\n            await asyncio.sleep(2.0**attempt)\n    raise Exception(\n        f\"Failed to fetch file list for {model_id=} {revision=} {path=} {recursive=}\"\n    )\n\n\nasync def _fetch_file_list(\n    model_id: ModelId, revision: str = \"main\", path: str = \"\", recursive: bool = False\n) -> list[FileListEntry]:\n    api_url = f\"{get_hf_endpoint()}/api/models/{model_id}/tree/{revision}\"\n    url = f\"{api_url}/{path}\" if path else api_url\n\n    headers = await get_download_headers()\n    async with (\n        create_http_session(timeout_profile=\"short\") as session,\n        session.get(url, 
headers=headers) as response,\n    ):\n        if response.status in [401, 403]:\n            msg = await _build_auth_error_message(response.status, model_id)\n            raise HuggingFaceAuthenticationError(msg)\n        elif response.status == 429:\n            raise HuggingFaceRateLimitError(\n                f\"Couldn't download {model_id} because of HuggingFace rate limit.\"\n            )\n        elif response.status == 200:\n            data_json = await response.text()\n            data = TypeAdapter(list[FileListEntry]).validate_json(data_json)\n            files: list[FileListEntry] = []\n            for item in data:\n                if item.type == \"file\":\n                    files.append(FileListEntry.model_validate(item))\n                elif item.type == \"directory\" and recursive:\n                    subfiles = await _fetch_file_list(\n                        model_id, revision, item.path, recursive\n                    )\n                    files.extend(subfiles)\n            return files\n        else:\n            raise Exception(f\"Failed to fetch file list: {response.status}\")\n\n\nasync def get_download_headers() -> dict[str, str]:\n    return {**(await get_auth_headers()), \"Accept-Encoding\": \"identity\"}\n\n\ndef create_http_session(\n    auto_decompress: bool = False,\n    timeout_profile: Literal[\"short\", \"long\"] = \"long\",\n) -> aiohttp.ClientSession:\n    if timeout_profile == \"short\":\n        total_timeout = 30\n        connect_timeout = 10\n        sock_read_timeout = 30\n        sock_connect_timeout = 10\n    else:\n        total_timeout = 1800\n        connect_timeout = 60\n        sock_read_timeout = 60\n        sock_connect_timeout = 60\n\n    ssl_context = ssl.create_default_context(\n        cafile=os.getenv(\"SSL_CERT_FILE\") or certifi.where()\n    )\n    connector = aiohttp.TCPConnector(ssl=ssl_context)\n\n    return aiohttp.ClientSession(\n        auto_decompress=auto_decompress,\n        connector=connector,\n        proxy=os.getenv(\"HTTPS_PROXY\") or os.getenv(\"HTTP_PROXY\") or None,\n        timeout=aiohttp.ClientTimeout(\n            total=total_timeout,\n            connect=connect_timeout,\n            sock_read=sock_read_timeout,\n            sock_connect=sock_connect_timeout,\n        ),\n    )\n\n\nasync def calc_hash(path: Path, hash_type: Literal[\"sha1\", \"sha256\"] = \"sha1\") -> str:\n    hasher = hashlib.sha1() if hash_type == \"sha1\" else hashlib.sha256()\n    if hash_type == \"sha1\":\n        header = f\"blob {(await aios.stat(path)).st_size}\\0\".encode()\n        hasher.update(header)\n    async with aiofiles.open(path, \"rb\") as f:\n        while chunk := await f.read(8 * 1024 * 1024):\n            hasher.update(chunk)\n    return hasher.hexdigest()\n\n\nasync def file_meta(\n    model_id: ModelId, revision: str, path: str, redirected_location: str | None = None\n) -> tuple[int, str]:\n    url = (\n        urljoin(f\"{get_hf_endpoint()}/{model_id}/resolve/{revision}/\", path)\n        if redirected_location is None\n        else f\"{get_hf_endpoint()}{redirected_location}\"\n    )\n    headers = await get_download_headers()\n    async with (\n        create_http_session(timeout_profile=\"short\") as session,\n        session.head(url, headers=headers) as r,\n    ):\n        if r.status == 307:\n            # On redirect, only trust Hugging Face's x-linked-* headers.\n            x_linked_size = r.headers.get(\"x-linked-size\")\n            x_linked_etag = r.headers.get(\"x-linked-etag\")\n            if 
x_linked_size and x_linked_etag:\n                content_length = int(x_linked_size)\n                etag = trim_etag(x_linked_etag)\n                return content_length, etag\n            # Otherwise, follow the redirect to get authoritative size/hash\n            redirected_location = r.headers.get(\"location\")\n            return await file_meta(model_id, revision, path, redirected_location)\n        if r.status in [401, 403]:\n            msg = await _build_auth_error_message(r.status, model_id)\n            raise HuggingFaceAuthenticationError(msg)\n        content_length = int(\n            r.headers.get(\"x-linked-size\") or r.headers.get(\"content-length\") or 0\n        )\n        etag = r.headers.get(\"x-linked-etag\") or r.headers.get(\"etag\")\n        assert content_length > 0, f\"No content length for {url}\"\n        assert etag is not None, f\"No remote hash for {url}\"\n        etag = trim_etag(etag)\n        return content_length, etag\n\n\nasync def download_file_with_retry(\n    model_id: ModelId,\n    revision: str,\n    path: str,\n    target_dir: Path,\n    on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,\n    on_connection_lost: Callable[[], None] = lambda: None,\n    skip_internet: bool = False,\n) -> Path:\n    n_attempts = 3\n    for attempt in range(n_attempts):\n        try:\n            return await _download_file(\n                model_id, revision, path, target_dir, on_progress, skip_internet\n            )\n        except HuggingFaceAuthenticationError:\n            raise\n        except FileNotFoundError:\n            raise\n        except HuggingFaceRateLimitError as e:\n            if attempt == n_attempts - 1:\n                raise e\n            logger.error(\n                f\"Download error on attempt {attempt + 1}/{n_attempts} for {model_id=} {revision=} {path=} {target_dir=}\"\n            )\n            logger.error(traceback.format_exc())\n            await asyncio.sleep(2.0**attempt)\n        except Exception as e:\n            if attempt == n_attempts - 1:\n                on_connection_lost()\n                raise e\n            logger.error(\n                f\"Download error on attempt {attempt + 1}/{n_attempts} for {model_id=} {revision=} {path=} {target_dir=}\"\n            )\n            logger.error(traceback.format_exc())\n            await asyncio.sleep(2.0**attempt)\n    raise Exception(\n        f\"Failed to download file {model_id=} {revision=} {path=} {target_dir=}\"\n    )\n\n\nasync def _download_file(\n    model_id: ModelId,\n    revision: str,\n    path: str,\n    target_dir: Path,\n    on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,\n    skip_internet: bool = False,\n) -> Path:\n    target_path = target_dir / path\n\n    if await aios.path.exists(target_path):\n        if skip_internet:\n            return target_path\n\n        local_size = (await aios.stat(target_path)).st_size\n\n        # Try to verify against remote, but allow offline operation\n        try:\n            remote_size, _ = await file_meta(model_id, revision, path)\n            if local_size != remote_size:\n                logger.info(\n                    f\"File {path} size mismatch (local={local_size}, remote={remote_size}), re-downloading\"\n                )\n                await aios.remove(target_path)\n            else:\n                return target_path\n        except Exception as e:\n            # Offline or network error - trust local file\n            logger.debug(\n                
f\"Could not verify {path} against remote (offline?): {e}, using local file\"\n            )\n            return target_path\n\n    if skip_internet:\n        raise FileNotFoundError(\n            f\"File {path} not found locally and cannot download in offline mode\"\n        )\n\n    await aios.makedirs((target_dir / path).parent, exist_ok=True)\n    length, etag = await file_meta(model_id, revision, path)\n    remote_hash = etag[:-5] if etag.endswith(\"-gzip\") else etag\n    partial_path = target_dir / f\"{path}.partial\"\n    resume_byte_pos = (\n        (await aios.stat(partial_path)).st_size\n        if (await aios.path.exists(partial_path))\n        else None\n    )\n    if resume_byte_pos != length:\n        url = urljoin(f\"{get_hf_endpoint()}/{model_id}/resolve/{revision}/\", path)\n        headers = await get_download_headers()\n        if resume_byte_pos:\n            headers[\"Range\"] = f\"bytes={resume_byte_pos}-\"\n        n_read = resume_byte_pos or 0\n        async with (\n            create_http_session(timeout_profile=\"long\") as session,\n            session.get(url, headers=headers) as r,\n        ):\n            if r.status == 404:\n                raise FileNotFoundError(f\"File not found: {url}\")\n            if r.status in [401, 403]:\n                msg = await _build_auth_error_message(r.status, model_id)\n                raise HuggingFaceAuthenticationError(msg)\n            assert r.status in [200, 206], (\n                f\"Failed to download {path} from {url}: {r.status}\"\n            )\n            async with aiofiles.open(\n                partial_path, \"ab\" if resume_byte_pos else \"wb\"\n            ) as f:\n                while chunk := await r.content.read(8 * 1024 * 1024):\n                    n_read = n_read + (await f.write(chunk))\n                    on_progress(n_read, length, False)\n\n    final_hash = await calc_hash(\n        partial_path, hash_type=\"sha256\" if len(remote_hash) == 64 else \"sha1\"\n    )\n    integrity = final_hash == remote_hash\n    if not integrity:\n        try:\n            await aios.remove(partial_path)\n        except Exception as e:\n            logger.error(f\"Error removing partial file {partial_path}: {e}\")\n        raise Exception(\n            f\"Downloaded file {target_dir / path} has hash {final_hash} but remote hash is {remote_hash}\"\n        )\n    await aios.rename(partial_path, target_dir / path)\n    on_progress(length, length, True)\n    return target_dir / path\n\n\ndef calculate_repo_progress(\n    shard: ShardMetadata,\n    model_id: ModelId,\n    revision: str,\n    file_progress: dict[str, RepoFileDownloadProgress],\n    all_start_time: float,\n) -> RepoDownloadProgress:\n    all_total = sum((p.total for p in file_progress.values()), Memory.from_bytes(0))\n    all_downloaded = sum(\n        (p.downloaded for p in file_progress.values()), Memory.from_bytes(0)\n    )\n    all_downloaded_this_session = sum(\n        (p.downloaded_this_session for p in file_progress.values()),\n        Memory.from_bytes(0),\n    )\n    elapsed_time = time.time() - all_start_time\n    all_speed = (\n        all_downloaded_this_session.in_bytes / elapsed_time if elapsed_time > 0 else 0\n    )\n    all_eta = (\n        timedelta(seconds=(all_total - all_downloaded).in_bytes / all_speed)\n        if all_speed > 0\n        else timedelta(seconds=0)\n    )\n    status = (\n        \"complete\"\n        if all(p.status == \"complete\" for p in file_progress.values())\n        else \"in_progress\"\n        if 
any(p.status == \"in_progress\" for p in file_progress.values())\n        else \"not_started\"\n    )\n    return RepoDownloadProgress(\n        repo_id=model_id,\n        repo_revision=revision,\n        shard=shard,\n        completed_files=len(\n            [p for p in file_progress.values() if p.downloaded == p.total]\n        ),\n        total_files=len(file_progress),\n        downloaded=all_downloaded,\n        downloaded_this_session=all_downloaded_this_session,\n        total=all_total,\n        overall_speed=all_speed,\n        overall_eta=all_eta,\n        status=status,\n        file_progress=file_progress,\n    )\n\n\nasync def get_weight_map(model_id: ModelId, revision: str = \"main\") -> dict[str, str]:\n    target_dir = (await ensure_models_dir()) / model_id.normalize()\n    await aios.makedirs(target_dir, exist_ok=True)\n\n    # snapshot_download performs blocking network and disk I/O; run it in a worker\n    # thread so it does not stall the event loop.\n    index_files_dir = await asyncio.to_thread(\n        snapshot_download,\n        repo_id=model_id,\n        local_dir=target_dir,\n        allow_patterns=\"*.safetensors.index.json\",\n    )\n\n    index_files = list(Path(index_files_dir).glob(\"**/*.safetensors.index.json\"))\n\n    weight_map: dict[str, str] = {}\n\n    for index_file in index_files:\n        relative_dir = index_file.parent.relative_to(index_files_dir)\n\n        async with aiofiles.open(index_file, \"r\") as f:\n            index_data = ModelSafetensorsIndex.model_validate_json(await f.read())\n\n            if relative_dir != Path(\".\"):\n                prefixed_weight_map = {\n                    f\"{relative_dir}/{key}\": str(relative_dir / value)\n                    for key, value in index_data.weight_map.items()\n                }\n                weight_map = weight_map | prefixed_weight_map\n            else:\n                weight_map = weight_map | index_data.weight_map\n\n    return weight_map\n\n\nasync def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:\n    # TODO: 'Smart' downloads are disabled because:\n    #   (i) We don't handle all kinds of files;\n    #  (ii) We don't have sticky sessions;\n    # (iii) Tensor parallel requires all files.\n    return [\"*\"]\n    try:\n        weight_map = await get_weight_map(str(shard.model_card.model_id))\n        return get_allow_patterns(weight_map, shard)\n    except Exception:\n        logger.error(f\"Error getting weight map for {shard.model_card.model_id=}\")\n        logger.error(traceback.format_exc())\n        return [\"*\"]\n\n\ndef is_image_model(shard: ShardMetadata) -> bool:\n    tasks = shard.model_card.tasks\n    return ModelTask.TextToImage in tasks or ModelTask.ImageToImage in tasks\n\n\nasync def get_downloaded_size(path: Path) -> int:\n    partial_path = path.with_suffix(path.suffix + \".partial\")\n    if await aios.path.exists(path):\n        return (await aios.stat(path)).st_size\n    if await aios.path.exists(partial_path):\n        return (await aios.stat(partial_path)).st_size\n    return 0\n\n\nasync def download_shard(\n    shard: ShardMetadata,\n    on_progress: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],\n    max_parallel_downloads: int = 8,\n    skip_download: bool = False,\n    skip_internet: bool = False,\n    allow_patterns: list[str] | None = None,\n    on_connection_lost: Callable[[], None] = lambda: None,\n) -> tuple[Path, RepoDownloadProgress]:\n    if not skip_download:\n        logger.debug(f\"Downloading {shard.model_card.model_id=}\")\n\n    revision = \"main\"\n    target_dir = await ensure_models_dir() / str(shard.model_card.model_id).replace(\n        \"/\", \"--\"\n    )\n    if 
not skip_download:\n        await aios.makedirs(target_dir, exist_ok=True)\n\n    if not allow_patterns:\n        allow_patterns = await resolve_allow_patterns(shard)\n\n    if not skip_download:\n        logger.debug(f\"Downloading {shard.model_card.model_id=} with {allow_patterns=}\")\n\n    all_start_time = time.time()\n    file_list = await fetch_file_list_with_cache(\n        shard.model_card.model_id,\n        revision,\n        recursive=True,\n        skip_internet=skip_internet,\n        on_connection_lost=on_connection_lost,\n    )\n    filtered_file_list = list(\n        filter_repo_objects(\n            file_list, allow_patterns=allow_patterns, key=lambda x: x.path\n        )\n    )\n\n    # For image models, skip root-level safetensors files since weights\n    # are stored in component subdirectories (e.g., transformer/, vae/)\n    if is_image_model(shard):\n        filtered_file_list = [\n            f\n            for f in filtered_file_list\n            if \"/\" in f.path or not f.path.endswith(\".safetensors\")\n        ]\n    file_progress: dict[str, RepoFileDownloadProgress] = {}\n\n    async def on_progress_wrapper(\n        file: FileListEntry, curr_bytes: int, total_bytes: int, is_renamed: bool\n    ) -> None:\n        previous_progress = file_progress.get(file.path)\n\n        # Detect re-download: curr_bytes < previous downloaded means file was deleted and restarted\n        is_redownload = (\n            previous_progress is not None\n            and curr_bytes < previous_progress.downloaded.in_bytes\n        )\n\n        if is_redownload or previous_progress is None:\n            # Fresh download or re-download: reset tracking\n            start_time = time.time()\n            downloaded_this_session = curr_bytes\n        else:\n            # Continuing download: accumulate\n            start_time = previous_progress.start_time\n            downloaded_this_session = (\n                previous_progress.downloaded_this_session.in_bytes\n                + (curr_bytes - previous_progress.downloaded.in_bytes)\n            )\n\n        speed = (\n            downloaded_this_session / (time.time() - start_time)\n            if time.time() - start_time > 0\n            else 0\n        )\n        eta = (\n            timedelta(seconds=(total_bytes - curr_bytes) / speed)\n            if speed > 0\n            else timedelta(seconds=0)\n        )\n        file_progress[file.path] = RepoFileDownloadProgress(\n            repo_id=shard.model_card.model_id,\n            repo_revision=revision,\n            file_path=file.path,\n            downloaded=Memory.from_bytes(curr_bytes),\n            downloaded_this_session=Memory.from_bytes(downloaded_this_session),\n            total=Memory.from_bytes(total_bytes),\n            speed=speed,\n            eta=eta,\n            status=\"complete\"\n            if curr_bytes == total_bytes and is_renamed\n            else \"in_progress\",\n            start_time=start_time,\n        )\n        await on_progress(\n            shard,\n            calculate_repo_progress(\n                shard,\n                shard.model_card.model_id,\n                revision,\n                file_progress,\n                all_start_time,\n            ),\n        )\n\n    for file in filtered_file_list:\n        downloaded_bytes = await get_downloaded_size(target_dir / file.path)\n        final_file_exists = await aios.path.exists(target_dir / file.path)\n        file_progress[file.path] = RepoFileDownloadProgress(\n            
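# Seed per-file progress from whatever is already on disk before downloading.
            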
repo_id=shard.model_card.model_id,\n            repo_revision=revision,\n            file_path=file.path,\n            downloaded=Memory.from_bytes(downloaded_bytes),\n            downloaded_this_session=Memory.from_bytes(0),\n            total=Memory.from_bytes(file.size or 0),\n            speed=0,\n            eta=timedelta(0),\n            status=\"complete\"\n            if final_file_exists and downloaded_bytes == file.size\n            else \"not_started\",\n            start_time=time.time(),\n        )\n\n    semaphore = asyncio.Semaphore(max_parallel_downloads)\n    progress_tasks: set[asyncio.Task[None]] = set()\n\n    def schedule_progress(\n        file: FileListEntry, curr_bytes: int, total_bytes: int, is_renamed: bool\n    ) -> None:\n        # Hold a strong reference to each progress task; otherwise the event loop\n        # may garbage-collect it before it runs.\n        task = asyncio.create_task(\n            on_progress_wrapper(file, curr_bytes, total_bytes, is_renamed)\n        )\n        progress_tasks.add(task)\n        task.add_done_callback(progress_tasks.discard)\n\n    async def download_with_semaphore(file: FileListEntry) -> None:\n        async with semaphore:\n            await download_file_with_retry(\n                shard.model_card.model_id,\n                revision,\n                file.path,\n                target_dir,\n                lambda curr_bytes, total_bytes, is_renamed: schedule_progress(\n                    file, curr_bytes, total_bytes, is_renamed\n                ),\n                on_connection_lost=on_connection_lost,\n                skip_internet=skip_internet,\n            )\n\n    if not skip_download:\n        await asyncio.gather(\n            *[download_with_semaphore(file) for file in filtered_file_list]\n        )\n    final_repo_progress = calculate_repo_progress(\n        shard, shard.model_card.model_id, revision, file_progress, all_start_time\n    )\n    await on_progress(shard, final_repo_progress)\n    if gguf := next((f for f in filtered_file_list if f.path.endswith(\".gguf\")), None):\n        return target_dir / gguf.path, final_repo_progress\n    else:\n        return target_dir, final_repo_progress\n"
  },
  {
    "path": "src/exo/download/huggingface_utils.py",
    "content": "import os\nfrom fnmatch import fnmatch\nfrom pathlib import Path\nfrom typing import Callable, Generator, Iterable\n\nimport aiofiles\nimport aiofiles.os as aios\nfrom loguru import logger\n\nfrom exo.shared.types.worker.shards import ShardMetadata\n\n\ndef filter_repo_objects[T](\n    items: Iterable[T],\n    *,\n    allow_patterns: list[str] | str | None = None,\n    ignore_patterns: list[str] | str | None = None,\n    key: Callable[[T], str] | None = None,\n) -> Generator[T, None, None]:\n    if isinstance(allow_patterns, str):\n        allow_patterns = [allow_patterns]\n    if isinstance(ignore_patterns, str):\n        ignore_patterns = [ignore_patterns]\n    if allow_patterns is not None:\n        allow_patterns = [_add_wildcard_to_directories(p) for p in allow_patterns]\n    if ignore_patterns is not None:\n        ignore_patterns = [_add_wildcard_to_directories(p) for p in ignore_patterns]\n\n    if key is None:\n\n        def _identity(item: T) -> str:\n            if isinstance(item, str):\n                return item\n            if isinstance(item, Path):\n                return str(item)\n            raise ValueError(\n                f\"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.\"\n            )\n\n        key = _identity\n\n    for item in items:\n        path = key(item)\n        if allow_patterns is not None and not any(\n            fnmatch(path, r) for r in allow_patterns\n        ):\n            continue\n        if ignore_patterns is not None and any(\n            fnmatch(path, r) for r in ignore_patterns\n        ):\n            continue\n        yield item\n\n\ndef _add_wildcard_to_directories(pattern: str) -> str:\n    if pattern[-1] == \"/\":\n        return pattern + \"*\"\n    return pattern\n\n\ndef get_hf_endpoint() -> str:\n    return os.environ.get(\"HF_ENDPOINT\", \"https://huggingface.co\")\n\n\ndef get_hf_home() -> Path:\n    \"\"\"Get the Hugging Face home directory.\"\"\"\n    return Path(os.environ.get(\"HF_HOME\", Path.home() / \".cache\" / \"huggingface\"))\n\n\nasync def get_hf_token() -> str | None:\n    \"\"\"Retrieve the Hugging Face token from HF_TOKEN env var or HF_HOME directory.\"\"\"\n    # Check environment variable first\n    if token := os.environ.get(\"HF_TOKEN\"):\n        return token\n    # Fall back to file-based token\n    token_path = get_hf_home() / \"token\"\n    if await aios.path.exists(token_path):\n        async with aiofiles.open(token_path, \"r\") as f:\n            return (await f.read()).strip()\n    return None\n\n\nasync def get_auth_headers() -> dict[str, str]:\n    \"\"\"Get authentication headers if a token is available.\"\"\"\n    token = await get_hf_token()\n    if token:\n        return {\"Authorization\": f\"Bearer {token}\"}\n    return {}\n\n\ndef extract_layer_num(tensor_name: str) -> int | None:\n    # This is a simple example and might need to be adjusted based on the actual naming convention\n    parts = tensor_name.split(\".\")\n    for part in parts:\n        if part.isdigit():\n            return int(part)\n    return None\n\n\ndef get_allow_patterns(weight_map: dict[str, str], shard: ShardMetadata) -> list[str]:\n    default_patterns = set(\n        [\n            \"*.json\",\n            \"*.py\",\n            \"tokenizer.model\",\n            \"tiktoken.model\",\n            \"*/spiece.model\",\n            \"*.tiktoken\",\n            \"*.txt\",\n            \"*.jinja\",\n        ]\n    )\n    shard_specific_patterns: set[str] = set()\n\n    if 
shard.model_card.components is not None:\n        shardable_component = next(\n            (c for c in shard.model_card.components if c.can_shard), None\n        )\n\n        if weight_map and shardable_component:\n            for tensor_name, filename in weight_map.items():\n                # Strip component prefix from tensor name (added by weight map namespacing)\n                # E.g., \"transformer/blocks.0.weight\" -> \"blocks.0.weight\"\n                if \"/\" in tensor_name:\n                    _, tensor_name_no_prefix = tensor_name.split(\"/\", 1)\n                else:\n                    tensor_name_no_prefix = tensor_name\n\n                # Determine which component this file belongs to from filename\n                component_path = Path(filename).parts[0] if \"/\" in filename else None\n\n                if component_path == shardable_component.component_path.rstrip(\"/\"):\n                    layer_num = extract_layer_num(tensor_name_no_prefix)\n                    if (\n                        layer_num is not None\n                        and shard.start_layer <= layer_num < shard.end_layer\n                    ):\n                        shard_specific_patterns.add(filename)\n\n                    if shard.is_first_layer or shard.is_last_layer:\n                        shard_specific_patterns.add(filename)\n                else:\n                    shard_specific_patterns.add(filename)\n\n        else:\n            shard_specific_patterns = {\"*.safetensors\"}\n\n        # TODO(ciaran): temporary - Include all files from non-shardable components that have no index file\n        for component in shard.model_card.components:\n            if not component.can_shard and component.safetensors_index_filename is None:\n                component_pattern = f\"{component.component_path.rstrip('/')}/*\"\n                shard_specific_patterns.add(component_pattern)\n    else:\n        if weight_map:\n            for tensor_name, filename in weight_map.items():\n                layer_num = extract_layer_num(tensor_name)\n                if (\n                    layer_num is not None\n                    and shard.start_layer <= layer_num < shard.end_layer\n                ):\n                    shard_specific_patterns.add(filename)\n            layer_independent_files = {\n                v for k, v in weight_map.items() if extract_layer_num(k) is None\n            }\n            shard_specific_patterns.update(layer_independent_files)\n            logger.debug(f\"get_allow_patterns {shard=} {layer_independent_files=}\")\n        else:\n            shard_specific_patterns = {\"*.safetensors\"}\n\n    logger.info(f\"get_allow_patterns {shard=} {shard_specific_patterns=}\")\n    return list(default_patterns | shard_specific_patterns)\n"
  },
  {
    "path": "src/exo/download/impl_shard_downloader.py",
    "content": "import asyncio\nfrom asyncio import create_task\nfrom collections.abc import Awaitable\nfrom pathlib import Path\nfrom typing import AsyncIterator, Callable\n\nfrom loguru import logger\n\nfrom exo.download.download_utils import RepoDownloadProgress, download_shard\nfrom exo.download.shard_downloader import ShardDownloader\nfrom exo.shared.models.model_cards import ModelCard, ModelId, get_model_cards\nfrom exo.shared.types.worker.shards import (\n    PipelineShardMetadata,\n    ShardMetadata,\n)\n\n\ndef exo_shard_downloader(\n    max_parallel_downloads: int = 8, offline: bool = False\n) -> ShardDownloader:\n    return SingletonShardDownloader(\n        ResumableShardDownloader(max_parallel_downloads, offline=offline)\n    )\n\n\nasync def build_base_shard(model_id: ModelId) -> ShardMetadata:\n    model_card = await ModelCard.load(model_id)\n    return PipelineShardMetadata(\n        model_card=model_card,\n        device_rank=0,\n        world_size=1,\n        start_layer=0,\n        end_layer=model_card.n_layers,\n        n_layers=model_card.n_layers,\n    )\n\n\nasync def build_full_shard(model_id: ModelId) -> PipelineShardMetadata:\n    base_shard = await build_base_shard(model_id)\n    return PipelineShardMetadata(\n        model_card=base_shard.model_card,\n        device_rank=base_shard.device_rank,\n        world_size=base_shard.world_size,\n        start_layer=base_shard.start_layer,\n        end_layer=base_shard.n_layers,\n        n_layers=base_shard.n_layers,\n    )\n\n\nclass SingletonShardDownloader(ShardDownloader):\n    def __init__(self, shard_downloader: ShardDownloader):\n        self.shard_downloader = shard_downloader\n        self.active_downloads: dict[ShardMetadata, asyncio.Task[Path]] = {}\n\n    def on_progress(\n        self,\n        callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],\n    ) -> None:\n        self.shard_downloader.on_progress(callback)\n\n    async def ensure_shard(\n        self, shard: ShardMetadata, config_only: bool = False\n    ) -> Path:\n        if shard not in self.active_downloads:\n            self.active_downloads[shard] = asyncio.create_task(\n                self.shard_downloader.ensure_shard(shard, config_only)\n            )\n        try:\n            return await self.active_downloads[shard]\n        finally:\n            if shard in self.active_downloads and self.active_downloads[shard].done():\n                del self.active_downloads[shard]\n\n    async def get_shard_download_status(\n        self,\n    ) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]:\n        async for path, status in self.shard_downloader.get_shard_download_status():\n            yield path, status\n\n    async def get_shard_download_status_for_shard(\n        self, shard: ShardMetadata\n    ) -> RepoDownloadProgress:\n        return await self.shard_downloader.get_shard_download_status_for_shard(shard)\n\n\nclass ResumableShardDownloader(ShardDownloader):\n    def __init__(self, max_parallel_downloads: int = 8, offline: bool = False):\n        self.max_parallel_downloads = max_parallel_downloads\n        self.offline = offline\n        self.on_progress_callbacks: list[\n            Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]]\n        ] = []\n\n    async def on_progress_wrapper(\n        self, shard: ShardMetadata, progress: RepoDownloadProgress\n    ) -> None:\n        for callback in self.on_progress_callbacks:\n            await callback(shard, progress)\n\n    def on_progress(\n        
self,\n        callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],\n    ) -> None:\n        self.on_progress_callbacks.append(callback)\n\n    async def ensure_shard(\n        self, shard: ShardMetadata, config_only: bool = False\n    ) -> Path:\n        allow_patterns = [\"config.json\"] if config_only else None\n\n        target_dir, _ = await download_shard(\n            shard,\n            self.on_progress_wrapper,\n            max_parallel_downloads=self.max_parallel_downloads,\n            allow_patterns=allow_patterns,\n            skip_internet=self.offline,\n        )\n        return target_dir\n\n    async def get_shard_download_status(\n        self,\n    ) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]:\n        async def _status_for_model(\n            model_id: ModelId,\n        ) -> tuple[Path, RepoDownloadProgress]:\n            \"\"\"Helper coroutine that builds the shard for a model and gets its download status.\"\"\"\n            shard = await build_full_shard(model_id)\n            return await download_shard(\n                shard,\n                self.on_progress_wrapper,\n                skip_download=True,\n                skip_internet=self.offline,\n            )\n\n        semaphore = asyncio.Semaphore(self.max_parallel_downloads)\n\n        async def download_with_semaphore(\n            model_card: ModelCard,\n        ) -> tuple[Path, RepoDownloadProgress]:\n            async with semaphore:\n                return await _status_for_model(model_card.model_id)\n\n        tasks = [\n            create_task(download_with_semaphore(model_card))\n            for model_card in await get_model_cards()\n        ]\n\n        for task in asyncio.as_completed(tasks):\n            try:\n                yield await task\n            except Exception as e:\n                logger.warning(\n                    f\"Error getting shard download status: {type(e).__name__}: {e}\"\n                )\n\n    async def get_shard_download_status_for_shard(\n        self, shard: ShardMetadata\n    ) -> RepoDownloadProgress:\n        _, progress = await download_shard(\n            shard,\n            self.on_progress_wrapper,\n            skip_download=True,\n            skip_internet=self.offline,\n        )\n        return progress\n"
  },
  {
    "path": "src/exo/download/shard_downloader.py",
    "content": "from abc import ABC, abstractmethod\nfrom collections.abc import Awaitable\nfrom copy import copy\nfrom datetime import timedelta\nfrom pathlib import Path\nfrom typing import AsyncIterator, Callable\n\nfrom exo.download.download_utils import RepoDownloadProgress\nfrom exo.shared.models.model_cards import ModelCard, ModelId, ModelTask\nfrom exo.shared.types.memory import Memory\nfrom exo.shared.types.worker.shards import (\n    PipelineShardMetadata,\n    ShardMetadata,\n)\n\n\n# TODO: the PipelineShardMetadata getting reinstantiated is a bit messy. Should this be a classmethod?\nclass ShardDownloader(ABC):\n    @abstractmethod\n    async def ensure_shard(\n        self, shard: ShardMetadata, config_only: bool = False\n    ) -> Path:\n        \"\"\"\n        Ensures that the shard is downloaded.\n        Does not allow multiple overlapping downloads at once.\n        If you try to download a Shard which overlaps a Shard that is already being downloaded,\n        the download will be cancelled and a new download will start.\n\n        Args:\n            shard (Shard): The shard to download.\n        \"\"\"\n\n    @abstractmethod\n    def on_progress(\n        self,\n        callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],\n    ) -> None:\n        pass\n\n    @abstractmethod\n    async def get_shard_download_status(\n        self,\n    ) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]:\n        \"\"\"Get the download status of shards.\n\n        Yields:\n            tuple[Path, RepoDownloadProgress]: The path and progress of a shard download.\n        \"\"\"\n        yield (Path(\"/tmp/noop_shard\"), NOOP_DOWNLOAD_PROGRESS)\n\n    @abstractmethod\n    async def get_shard_download_status_for_shard(\n        self, shard: ShardMetadata\n    ) -> RepoDownloadProgress: ...\n\n\nclass NoopShardDownloader(ShardDownloader):\n    async def ensure_shard(\n        self, shard: ShardMetadata, config_only: bool = False\n    ) -> Path:\n        return Path(\"/tmp/noop_shard\")\n\n    def on_progress(\n        self,\n        callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],\n    ) -> None:\n        pass\n\n    async def get_shard_download_status(\n        self,\n    ) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]:\n        yield (\n            Path(\"/tmp/noop_shard\"),\n            NOOP_DOWNLOAD_PROGRESS,\n        )\n\n    async def get_shard_download_status_for_shard(\n        self, shard: ShardMetadata\n    ) -> RepoDownloadProgress:\n        dp = copy(NOOP_DOWNLOAD_PROGRESS)\n        dp.shard = shard\n        return dp\n\n\nNOOP_DOWNLOAD_PROGRESS = RepoDownloadProgress(\n    repo_id=\"noop\",\n    repo_revision=\"noop\",\n    shard=PipelineShardMetadata(\n        model_card=ModelCard(\n            model_id=ModelId(\"noop\"),\n            storage_size=Memory.from_bytes(0),\n            n_layers=1,\n            hidden_size=1,\n            supports_tensor=False,\n            tasks=[ModelTask.TextGeneration],\n        ),\n        device_rank=0,\n        world_size=1,\n        start_layer=0,\n        end_layer=1,\n        n_layers=1,\n    ),\n    completed_files=0,\n    total_files=0,\n    downloaded=Memory.from_bytes(0),\n    downloaded_this_session=Memory.from_bytes(0),\n    total=Memory.from_bytes(0),\n    overall_speed=0,\n    overall_eta=timedelta(seconds=0),\n    status=\"complete\",\n)\n"
  },
  {
    "path": "src/exo/download/tests/__init__.py",
    "content": ""
  },
  {
    "path": "src/exo/download/tests/test_download_verification.py",
    "content": "\"\"\"Tests for download verification and cache behavior.\"\"\"\n\nimport time\nfrom collections.abc import AsyncIterator\nfrom datetime import timedelta\nfrom pathlib import Path\nfrom unittest.mock import AsyncMock, MagicMock, patch\n\nimport aiofiles\nimport aiofiles.os as aios\nimport pytest\nfrom pydantic import TypeAdapter\n\nfrom exo.download.download_utils import (\n    delete_model,\n    fetch_file_list_with_cache,\n)\nfrom exo.shared.types.common import ModelId\nfrom exo.shared.types.memory import Memory\nfrom exo.shared.types.worker.downloads import FileListEntry, RepoFileDownloadProgress\n\n\n@pytest.fixture\ndef model_id() -> ModelId:\n    return ModelId(\"test-org/test-model\")\n\n\n@pytest.fixture\nasync def temp_models_dir(tmp_path: Path) -> AsyncIterator[Path]:\n    \"\"\"Set up a temporary models directory for testing.\"\"\"\n    models_dir = tmp_path / \"models\"\n    await aios.makedirs(models_dir, exist_ok=True)\n    with patch(\"exo.download.download_utils.EXO_MODELS_DIR\", models_dir):\n        yield models_dir\n\n\nclass TestFileVerification:\n    \"\"\"Tests for file size verification in _download_file.\"\"\"\n\n    async def test_redownload_when_file_size_changes_upstream(\n        self, model_id: ModelId, tmp_path: Path\n    ) -> None:\n        \"\"\"Test that files with mismatched sizes are re-downloaded.\"\"\"\n        # Import inside test to allow patching\n        from exo.download.download_utils import (\n            _download_file,  # pyright: ignore[reportPrivateUsage]\n        )\n\n        target_dir = tmp_path / \"downloads\"\n        await aios.makedirs(target_dir, exist_ok=True)\n\n        # Create a local file with wrong size\n        local_file = target_dir / \"test.safetensors\"\n        async with aiofiles.open(local_file, \"wb\") as f:\n            await f.write(b\"local content\")  # 13 bytes\n\n        remote_size = 1000  # Different from local\n        remote_hash = \"abc123\"\n\n        with (\n            patch(\n                \"exo.download.download_utils.file_meta\",\n                new_callable=AsyncMock,\n                return_value=(remote_size, remote_hash),\n            ) as mock_file_meta,\n            patch(\n                \"exo.download.download_utils.create_http_session\"\n            ) as mock_session_factory,\n        ):\n            # Set up mock HTTP response for re-download\n            mock_response = MagicMock()\n            mock_response.status = 200\n            mock_response.content.read = AsyncMock(  # pyright: ignore[reportAny]\n                side_effect=[b\"x\" * remote_size, b\"\"]\n            )\n\n            mock_session = MagicMock()\n            mock_session.get.return_value.__aenter__ = AsyncMock(  # pyright: ignore[reportAny]\n                return_value=mock_response\n            )\n            mock_session.get.return_value.__aexit__ = AsyncMock(  # pyright: ignore[reportAny]\n                return_value=None\n            )\n            mock_session_factory.return_value.__aenter__ = AsyncMock(  # pyright: ignore[reportAny]\n                return_value=mock_session\n            )\n            mock_session_factory.return_value.__aexit__ = AsyncMock(  # pyright: ignore[reportAny]\n                return_value=None\n            )\n\n            # Mock calc_hash to return the expected hash\n            with patch(\n                \"exo.download.download_utils.calc_hash\",\n                new_callable=AsyncMock,\n                return_value=remote_hash,\n            ):\n                
await _download_file(model_id, \"main\", \"test.safetensors\", target_dir)\n\n            # file_meta should be called twice: once for verification, once for download\n            assert mock_file_meta.call_count == 2\n\n    async def test_skip_download_when_file_size_matches(\n        self, model_id: ModelId, tmp_path: Path\n    ) -> None:\n        \"\"\"Test that files with matching sizes are not re-downloaded.\"\"\"\n        from exo.download.download_utils import (\n            _download_file,  # pyright: ignore[reportPrivateUsage]\n        )\n\n        target_dir = tmp_path / \"downloads\"\n        await aios.makedirs(target_dir, exist_ok=True)\n\n        # Create a local file\n        local_file = target_dir / \"test.safetensors\"\n        local_content = b\"local content\"\n        async with aiofiles.open(local_file, \"wb\") as f:\n            await f.write(local_content)\n\n        remote_size = len(local_content)  # Same as local\n        remote_hash = \"abc123\"\n\n        with (\n            patch(\n                \"exo.download.download_utils.file_meta\",\n                new_callable=AsyncMock,\n                return_value=(remote_size, remote_hash),\n            ) as mock_file_meta,\n            patch(\n                \"exo.download.download_utils.create_http_session\"\n            ) as mock_session_factory,\n        ):\n            result = await _download_file(\n                model_id, \"main\", \"test.safetensors\", target_dir\n            )\n\n            # Should return immediately without downloading\n            assert result == local_file\n            mock_file_meta.assert_called_once()\n            mock_session_factory.assert_not_called()\n\n    async def test_offline_fallback_uses_local_file(\n        self, model_id: ModelId, tmp_path: Path\n    ) -> None:\n        \"\"\"Test that local files are used when network is unavailable.\"\"\"\n        from exo.download.download_utils import (\n            _download_file,  # pyright: ignore[reportPrivateUsage]\n        )\n\n        target_dir = tmp_path / \"downloads\"\n        await aios.makedirs(target_dir, exist_ok=True)\n\n        # Create a local file\n        local_file = target_dir / \"test.safetensors\"\n        async with aiofiles.open(local_file, \"wb\") as f:\n            await f.write(b\"local content\")\n\n        with (\n            patch(\n                \"exo.download.download_utils.file_meta\",\n                new_callable=AsyncMock,\n                side_effect=Exception(\"Network error\"),\n            ),\n            patch(\n                \"exo.download.download_utils.create_http_session\"\n            ) as mock_session_factory,\n        ):\n            result = await _download_file(\n                model_id, \"main\", \"test.safetensors\", target_dir\n            )\n\n            # Should return local file without attempting download\n            assert result == local_file\n            mock_session_factory.assert_not_called()\n\n\nclass TestFileListCache:\n    \"\"\"Tests for file list caching behavior.\"\"\"\n\n    async def test_fetch_fresh_and_update_cache(\n        self, model_id: ModelId, tmp_path: Path\n    ) -> None:\n        \"\"\"Test that fresh data is fetched and cache is updated.\"\"\"\n        models_dir = tmp_path / \"models\"\n\n        file_list = [\n            FileListEntry(type=\"file\", path=\"model.safetensors\", size=1000),\n            FileListEntry(type=\"file\", path=\"config.json\", size=100),\n        ]\n\n        with (\n            
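# Point EXO_MODELS_DIR at the temp directory so the cache file lands under tmp_path.
            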
patch(\"exo.download.download_utils.EXO_MODELS_DIR\", models_dir),\n            patch(\n                \"exo.download.download_utils.fetch_file_list_with_retry\",\n                new_callable=AsyncMock,\n                return_value=file_list,\n            ) as mock_fetch,\n        ):\n            result = await fetch_file_list_with_cache(model_id, \"main\")\n\n            assert result == file_list\n            mock_fetch.assert_called_once()\n\n            # Verify cache was written\n            cache_file = (\n                models_dir\n                / \"caches\"\n                / model_id.normalize()\n                / f\"{model_id.normalize()}--main--file_list.json\"\n            )\n            assert await aios.path.exists(cache_file)\n\n            async with aiofiles.open(cache_file, \"r\") as f:\n                cached_data = TypeAdapter(list[FileListEntry]).validate_json(\n                    await f.read()\n                )\n            assert cached_data == file_list\n\n    async def test_fallback_to_cache_when_fetch_fails(\n        self, model_id: ModelId, tmp_path: Path\n    ) -> None:\n        \"\"\"Test that cached data is used when fetch fails.\"\"\"\n        models_dir = tmp_path / \"models\"\n        cache_dir = models_dir / \"caches\" / model_id.normalize()\n        await aios.makedirs(cache_dir, exist_ok=True)\n\n        # Create cache file\n        cached_file_list = [\n            FileListEntry(type=\"file\", path=\"model.safetensors\", size=1000),\n        ]\n        cache_file = cache_dir / f\"{model_id.normalize()}--main--file_list.json\"\n        async with aiofiles.open(cache_file, \"w\") as f:\n            await f.write(\n                TypeAdapter(list[FileListEntry]).dump_json(cached_file_list).decode()\n            )\n\n        with (\n            patch(\"exo.download.download_utils.EXO_MODELS_DIR\", models_dir),\n            patch(\n                \"exo.download.download_utils.fetch_file_list_with_retry\",\n                new_callable=AsyncMock,\n                side_effect=Exception(\"Network error\"),\n            ),\n        ):\n            result = await fetch_file_list_with_cache(model_id, \"main\")\n\n            assert result == cached_file_list\n\n    async def test_error_propagates_when_no_cache(\n        self, model_id: ModelId, tmp_path: Path\n    ) -> None:\n        \"\"\"Test that errors propagate when fetch fails and no cache exists.\"\"\"\n        models_dir = tmp_path / \"models\"\n\n        with (\n            patch(\"exo.download.download_utils.EXO_MODELS_DIR\", models_dir),\n            patch(\n                \"exo.download.download_utils.fetch_file_list_with_retry\",\n                new_callable=AsyncMock,\n                side_effect=Exception(\"Network error\"),\n            ),\n            pytest.raises(Exception, match=\"Network error\"),\n        ):\n            await fetch_file_list_with_cache(model_id, \"main\")\n\n\nclass TestModelDeletion:\n    \"\"\"Tests for model deletion including cache cleanup.\"\"\"\n\n    async def test_delete_model_clears_cache(\n        self, model_id: ModelId, tmp_path: Path\n    ) -> None:\n        \"\"\"Test that deleting a model also deletes its cache.\"\"\"\n        models_dir = tmp_path / \"models\"\n        model_dir = models_dir / model_id.normalize()\n        cache_dir = models_dir / \"caches\" / model_id.normalize()\n\n        # Create model and cache directories\n        await aios.makedirs(model_dir, exist_ok=True)\n        await aios.makedirs(cache_dir, exist_ok=True)\n\n        
# Add some files\n        async with aiofiles.open(model_dir / \"model.safetensors\", \"w\") as f:\n            await f.write(\"model data\")\n        async with aiofiles.open(cache_dir / \"file_list.json\", \"w\") as f:\n            await f.write(\"[]\")\n\n        with patch(\"exo.download.download_utils.EXO_MODELS_DIR\", models_dir):\n            result = await delete_model(model_id)\n\n            assert result is True\n            assert not await aios.path.exists(model_dir)\n            assert not await aios.path.exists(cache_dir)\n\n    async def test_delete_model_only_cache_exists(\n        self, model_id: ModelId, tmp_path: Path\n    ) -> None:\n        \"\"\"Test deleting when only cache exists (model already deleted).\"\"\"\n        models_dir = tmp_path / \"models\"\n        cache_dir = models_dir / \"caches\" / model_id.normalize()\n\n        # Only create cache directory\n        await aios.makedirs(cache_dir, exist_ok=True)\n        async with aiofiles.open(cache_dir / \"file_list.json\", \"w\") as f:\n            await f.write(\"[]\")\n\n        with patch(\"exo.download.download_utils.EXO_MODELS_DIR\", models_dir):\n            result = await delete_model(model_id)\n\n            # Returns False because model dir didn't exist\n            assert result is False\n            # But cache should still be cleaned up\n            assert not await aios.path.exists(cache_dir)\n\n    async def test_delete_nonexistent_model(\n        self, model_id: ModelId, tmp_path: Path\n    ) -> None:\n        \"\"\"Test deleting a model that doesn't exist.\"\"\"\n        models_dir = tmp_path / \"models\"\n        await aios.makedirs(models_dir, exist_ok=True)\n\n        with patch(\"exo.download.download_utils.EXO_MODELS_DIR\", models_dir):\n            result = await delete_model(model_id)\n\n            assert result is False\n\n\nclass TestProgressResetOnRedownload:\n    \"\"\"Tests for progress tracking when files are re-downloaded.\"\"\"\n\n    async def test_progress_resets_correctly_on_redownload(\n        self, model_id: ModelId\n    ) -> None:\n        \"\"\"Test that progress tracking resets when a file is re-downloaded.\n\n        When a file is deleted and re-downloaded (due to size mismatch),\n        the progress tracking should reset rather than calculating negative\n        downloaded_this_session values.\n        \"\"\"\n        # Simulate file_progress dict as it exists in download_shard\n        file_progress: dict[str, RepoFileDownloadProgress] = {}\n\n        # Initialize with old file progress (simulating existing large file)\n        old_file_size = 1_500_000_000  # 1.5 GB\n        file_progress[\"model.safetensors\"] = RepoFileDownloadProgress(\n            repo_id=model_id,\n            repo_revision=\"main\",\n            file_path=\"model.safetensors\",\n            downloaded=Memory.from_bytes(old_file_size),\n            downloaded_this_session=Memory.from_bytes(0),\n            total=Memory.from_bytes(old_file_size),\n            speed=0,\n            eta=timedelta(0),\n            status=\"not_started\",\n            start_time=time.time() - 10,  # Started 10 seconds ago\n        )\n\n        # Simulate the logic from on_progress_wrapper after re-download starts\n        # This is the exact logic from the fixed on_progress_wrapper\n        curr_bytes = 100_000  # 100 KB - new download just started\n        previous_progress = file_progress.get(\"model.safetensors\")\n\n        # Detect re-download: curr_bytes < previous downloaded\n        is_redownload = (\n    
        previous_progress is not None\n            and curr_bytes < previous_progress.downloaded.in_bytes\n        )\n\n        if is_redownload or previous_progress is None:\n            # Fresh download or re-download: reset tracking\n            start_time = time.time()\n            downloaded_this_session = curr_bytes\n        else:\n            # Continuing download: accumulate\n            start_time = previous_progress.start_time\n            downloaded_this_session = (\n                previous_progress.downloaded_this_session.in_bytes\n                + (curr_bytes - previous_progress.downloaded.in_bytes)\n            )\n\n        # Key assertions\n        assert is_redownload is True, \"Should detect re-download scenario\"\n        assert downloaded_this_session == curr_bytes, (\n            \"downloaded_this_session should equal curr_bytes on re-download\"\n        )\n        assert downloaded_this_session > 0, (\n            \"downloaded_this_session should be positive, not negative\"\n        )\n\n        # Calculate speed (should be positive)\n        elapsed = time.time() - start_time\n        speed = downloaded_this_session / elapsed if elapsed > 0 else 0\n        assert speed >= 0, \"Speed should be non-negative\"\n\n    async def test_progress_accumulates_on_continuing_download(\n        self, model_id: ModelId\n    ) -> None:\n        \"\"\"Test that progress accumulates correctly for continuing downloads.\n\n        When a download continues from where it left off (resume),\n        the progress should accumulate correctly.\n        \"\"\"\n        file_progress: dict[str, RepoFileDownloadProgress] = {}\n\n        # Initialize with partial download progress\n        initial_downloaded = 500_000  # 500 KB already downloaded\n        start_time = time.time() - 5  # Started 5 seconds ago\n        file_progress[\"model.safetensors\"] = RepoFileDownloadProgress(\n            repo_id=model_id,\n            repo_revision=\"main\",\n            file_path=\"model.safetensors\",\n            downloaded=Memory.from_bytes(initial_downloaded),\n            downloaded_this_session=Memory.from_bytes(initial_downloaded),\n            total=Memory.from_bytes(1_000_000),\n            speed=100_000,\n            eta=timedelta(seconds=5),\n            status=\"in_progress\",\n            start_time=start_time,\n        )\n\n        # Progress callback with more bytes downloaded\n        curr_bytes = 600_000  # 600 KB - continuing download\n        previous_progress = file_progress.get(\"model.safetensors\")\n\n        # This is NOT a re-download (curr_bytes > previous downloaded)\n        is_redownload = (\n            previous_progress is not None\n            and curr_bytes < previous_progress.downloaded.in_bytes\n        )\n\n        if is_redownload or previous_progress is None:\n            downloaded_this_session = curr_bytes\n            used_start_time = time.time()\n        else:\n            used_start_time = previous_progress.start_time\n            downloaded_this_session = (\n                previous_progress.downloaded_this_session.in_bytes\n                + (curr_bytes - previous_progress.downloaded.in_bytes)\n            )\n\n        # Key assertions\n        assert is_redownload is False, (\n            \"Should NOT detect re-download for continuing download\"\n        )\n        assert used_start_time == start_time, \"Should preserve original start_time\"\n        expected_session = initial_downloaded + (curr_bytes - initial_downloaded)\n        assert 
downloaded_this_session == expected_session, (\n            f\"Should accumulate: {downloaded_this_session} == {expected_session}\"\n        )\n        assert downloaded_this_session == 600_000, (\n            \"downloaded_this_session should equal total downloaded so far\"\n        )\n"
  },
  {
    "path": "src/exo/download/tests/test_offline_mode.py",
    "content": "\"\"\"Tests for offline/air-gapped mode.\"\"\"\n\nfrom collections.abc import AsyncIterator\nfrom pathlib import Path\nfrom unittest.mock import AsyncMock, patch\n\nimport aiofiles\nimport aiofiles.os as aios\nimport pytest\n\nfrom exo.download.download_utils import (\n    _download_file,  # pyright: ignore[reportPrivateUsage]\n    download_file_with_retry,\n    fetch_file_list_with_cache,\n)\nfrom exo.shared.types.common import ModelId\nfrom exo.shared.types.worker.downloads import FileListEntry\n\n\n@pytest.fixture\ndef model_id() -> ModelId:\n    return ModelId(\"test-org/test-model\")\n\n\n@pytest.fixture\nasync def temp_models_dir(tmp_path: Path) -> AsyncIterator[Path]:\n    models_dir = tmp_path / \"models\"\n    await aios.makedirs(models_dir, exist_ok=True)\n    with patch(\"exo.download.download_utils.EXO_MODELS_DIR\", models_dir):\n        yield models_dir\n\n\nclass TestDownloadFileOffline:\n    \"\"\"Tests for _download_file with skip_internet=True.\"\"\"\n\n    async def test_returns_local_file_without_http_verification(\n        self, model_id: ModelId, tmp_path: Path\n    ) -> None:\n        \"\"\"When skip_internet=True and file exists locally, return it immediately\n        without making any HTTP calls (no file_meta verification).\"\"\"\n        target_dir = tmp_path / \"downloads\"\n        await aios.makedirs(target_dir, exist_ok=True)\n\n        local_file = target_dir / \"model.safetensors\"\n        async with aiofiles.open(local_file, \"wb\") as f:\n            await f.write(b\"model weights data\")\n\n        with patch(\n            \"exo.download.download_utils.file_meta\",\n            new_callable=AsyncMock,\n        ) as mock_file_meta:\n            result = await _download_file(\n                model_id,\n                \"main\",\n                \"model.safetensors\",\n                target_dir,\n                skip_internet=True,\n            )\n\n            assert result == local_file\n            mock_file_meta.assert_not_called()\n\n    async def test_raises_file_not_found_for_missing_file(\n        self, model_id: ModelId, tmp_path: Path\n    ) -> None:\n        \"\"\"When skip_internet=True and file does NOT exist locally,\n        raise FileNotFoundError instead of attempting download.\"\"\"\n        target_dir = tmp_path / \"downloads\"\n        await aios.makedirs(target_dir, exist_ok=True)\n\n        with pytest.raises(FileNotFoundError, match=\"offline mode\"):\n            await _download_file(\n                model_id,\n                \"main\",\n                \"missing_model.safetensors\",\n                target_dir,\n                skip_internet=True,\n            )\n\n    async def test_returns_local_file_in_subdirectory(\n        self, model_id: ModelId, tmp_path: Path\n    ) -> None:\n        \"\"\"When skip_internet=True and file exists in a subdirectory,\n        return it without HTTP calls.\"\"\"\n        target_dir = tmp_path / \"downloads\"\n        subdir = target_dir / \"transformer\"\n        await aios.makedirs(subdir, exist_ok=True)\n\n        local_file = subdir / \"diffusion_pytorch_model.safetensors\"\n        async with aiofiles.open(local_file, \"wb\") as f:\n            await f.write(b\"weights\")\n\n        with patch(\n            \"exo.download.download_utils.file_meta\",\n            new_callable=AsyncMock,\n        ) as mock_file_meta:\n            result = await _download_file(\n                model_id,\n                \"main\",\n                
\"transformer/diffusion_pytorch_model.safetensors\",\n                target_dir,\n                skip_internet=True,\n            )\n\n            assert result == local_file\n            mock_file_meta.assert_not_called()\n\n\nclass TestDownloadFileWithRetryOffline:\n    \"\"\"Tests for download_file_with_retry with skip_internet=True.\"\"\"\n\n    async def test_propagates_skip_internet_to_download_file(\n        self, model_id: ModelId, tmp_path: Path\n    ) -> None:\n        \"\"\"Verify skip_internet is passed through to _download_file.\"\"\"\n        target_dir = tmp_path / \"downloads\"\n        await aios.makedirs(target_dir, exist_ok=True)\n\n        local_file = target_dir / \"config.json\"\n        async with aiofiles.open(local_file, \"wb\") as f:\n            await f.write(b'{\"model_type\": \"qwen2\"}')\n\n        with patch(\n            \"exo.download.download_utils.file_meta\",\n            new_callable=AsyncMock,\n        ) as mock_file_meta:\n            result = await download_file_with_retry(\n                model_id,\n                \"main\",\n                \"config.json\",\n                target_dir,\n                skip_internet=True,\n            )\n\n            assert result == local_file\n            mock_file_meta.assert_not_called()\n\n    async def test_file_not_found_does_not_retry(\n        self, model_id: ModelId, tmp_path: Path\n    ) -> None:\n        \"\"\"FileNotFoundError from offline mode should not trigger retries.\"\"\"\n        target_dir = tmp_path / \"downloads\"\n        await aios.makedirs(target_dir, exist_ok=True)\n\n        with pytest.raises(FileNotFoundError):\n            await download_file_with_retry(\n                model_id,\n                \"main\",\n                \"nonexistent.safetensors\",\n                target_dir,\n                skip_internet=True,\n            )\n\n\nclass TestFetchFileListOffline:\n    \"\"\"Tests for fetch_file_list_with_cache with skip_internet=True.\"\"\"\n\n    async def test_uses_cached_file_list(\n        self, model_id: ModelId, temp_models_dir: Path\n    ) -> None:\n        \"\"\"When skip_internet=True and cache file exists, use it without network.\"\"\"\n        from pydantic import TypeAdapter\n\n        cache_dir = temp_models_dir / \"caches\" / model_id.normalize()\n        await aios.makedirs(cache_dir, exist_ok=True)\n\n        cached_list = [\n            FileListEntry(type=\"file\", path=\"model.safetensors\", size=1000),\n            FileListEntry(type=\"file\", path=\"config.json\", size=200),\n        ]\n        cache_file = cache_dir / f\"{model_id.normalize()}--main--file_list.json\"\n        async with aiofiles.open(cache_file, \"w\") as f:\n            await f.write(\n                TypeAdapter(list[FileListEntry]).dump_json(cached_list).decode()\n            )\n\n        with patch(\n            \"exo.download.download_utils.fetch_file_list_with_retry\",\n            new_callable=AsyncMock,\n        ) as mock_fetch:\n            result = await fetch_file_list_with_cache(\n                model_id, \"main\", skip_internet=True\n            )\n\n            assert result == cached_list\n            mock_fetch.assert_not_called()\n\n    async def test_falls_back_to_local_directory_scan(\n        self, model_id: ModelId, temp_models_dir: Path\n    ) -> None:\n        \"\"\"When skip_internet=True and no cache but local files exist,\n        build file list from local directory.\"\"\"\n        import json\n\n        model_dir = temp_models_dir / model_id.normalize()\n      
  await aios.makedirs(model_dir, exist_ok=True)\n\n        async with aiofiles.open(model_dir / \"config.json\", \"w\") as f:\n            await f.write('{\"model_type\": \"qwen2\"}')\n\n        index_data = {\n            \"metadata\": {},\n            \"weight_map\": {\"model.layers.0.weight\": \"model.safetensors\"},\n        }\n        async with aiofiles.open(model_dir / \"model.safetensors.index.json\", \"w\") as f:\n            await f.write(json.dumps(index_data))\n\n        async with aiofiles.open(model_dir / \"model.safetensors\", \"wb\") as f:\n            await f.write(b\"x\" * 500)\n\n        with patch(\n            \"exo.download.download_utils.fetch_file_list_with_retry\",\n            new_callable=AsyncMock,\n        ) as mock_fetch:\n            result = await fetch_file_list_with_cache(\n                model_id, \"main\", skip_internet=True\n            )\n\n            mock_fetch.assert_not_called()\n            paths = {entry.path for entry in result}\n            assert \"config.json\" in paths\n            assert \"model.safetensors\" in paths\n\n    async def test_raises_when_no_cache_and_no_local_files(\n        self, model_id: ModelId, temp_models_dir: Path\n    ) -> None:\n        \"\"\"When skip_internet=True and neither cache nor local files exist,\n        raise FileNotFoundError.\"\"\"\n        with pytest.raises(FileNotFoundError, match=\"No internet\"):\n            await fetch_file_list_with_cache(model_id, \"main\", skip_internet=True)\n"
  },
  {
    "path": "src/exo/download/tests/test_re_download.py",
    "content": "\"\"\"Tests that re-downloading a previously deleted model completes successfully.\"\"\"\n\nimport asyncio\nimport contextlib\nfrom collections.abc import AsyncIterator, Awaitable\nfrom datetime import timedelta\nfrom pathlib import Path\nfrom typing import Callable\nfrom unittest.mock import AsyncMock, patch\n\nfrom exo.download.coordinator import DownloadCoordinator\nfrom exo.download.download_utils import RepoDownloadProgress\nfrom exo.download.impl_shard_downloader import SingletonShardDownloader\nfrom exo.download.shard_downloader import ShardDownloader\nfrom exo.shared.models.model_cards import ModelCard, ModelId, ModelTask\nfrom exo.shared.types.commands import (\n    DeleteDownload,\n    ForwarderDownloadCommand,\n    StartDownload,\n)\nfrom exo.shared.types.common import NodeId, SystemId\nfrom exo.shared.types.events import Event, NodeDownloadProgress\nfrom exo.shared.types.memory import Memory\nfrom exo.shared.types.worker.downloads import DownloadCompleted\nfrom exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata\nfrom exo.utils.channels import Receiver, Sender, channel\n\nNODE_ID = NodeId(\"aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa\")\nMODEL_ID = ModelId(\"test-org/test-model\")\n\n\ndef _make_shard(model_id: ModelId = MODEL_ID) -> ShardMetadata:\n    return PipelineShardMetadata(\n        model_card=ModelCard(\n            model_id=model_id,\n            storage_size=Memory.from_mb(100),\n            n_layers=28,\n            hidden_size=1024,\n            supports_tensor=False,\n            tasks=[ModelTask.TextGeneration],\n        ),\n        device_rank=0,\n        world_size=1,\n        start_layer=0,\n        end_layer=28,\n        n_layers=28,\n    )\n\n\nclass FakeShardDownloader(ShardDownloader):\n    \"\"\"Fake downloader that simulates a successful download by firing the\n    progress callback with status='complete' when ensure_shard is called.\"\"\"\n\n    def __init__(self) -> None:\n        self._progress_callbacks: list[\n            Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]]\n        ] = []\n\n    def on_progress(\n        self,\n        callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],\n    ) -> None:\n        self._progress_callbacks.append(callback)\n\n    async def ensure_shard(\n        self,\n        shard: ShardMetadata,\n        config_only: bool = False,  # noqa: ARG002\n    ) -> Path:\n        # Simulate a completed download by firing the progress callback\n        progress = RepoDownloadProgress(\n            repo_id=str(shard.model_card.model_id),\n            repo_revision=\"main\",\n            shard=shard,\n            completed_files=1,\n            total_files=1,\n            downloaded=Memory.from_mb(100),\n            downloaded_this_session=Memory.from_mb(100),\n            total=Memory.from_mb(100),\n            overall_speed=0,\n            overall_eta=timedelta(seconds=0),\n            status=\"complete\",\n        )\n        for cb in self._progress_callbacks:\n            await cb(shard, progress)\n        return Path(\"/fake/models\") / shard.model_card.model_id.normalize()\n\n    async def get_shard_download_status(\n        self,\n    ) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]:\n        if False:  # noqa: SIM108  # empty async generator\n            yield (\n                Path(),\n                RepoDownloadProgress(  # pyright: ignore[reportUnreachable]\n                    repo_id=\"\",\n                    repo_revision=\"\",\n          
          shard=_make_shard(),\n                    completed_files=0,\n                    total_files=0,\n                    downloaded=Memory.from_bytes(0),\n                    downloaded_this_session=Memory.from_bytes(0),\n                    total=Memory.from_bytes(0),\n                    overall_speed=0,\n                    overall_eta=timedelta(seconds=0),\n                    status=\"not_started\",\n                ),\n            )\n\n    async def get_shard_download_status_for_shard(\n        self,\n        shard: ShardMetadata,\n    ) -> RepoDownloadProgress:\n        return RepoDownloadProgress(\n            repo_id=str(shard.model_card.model_id),\n            repo_revision=\"main\",\n            shard=shard,\n            completed_files=0,\n            total_files=1,\n            downloaded=Memory.from_bytes(0),\n            downloaded_this_session=Memory.from_bytes(0),\n            total=Memory.from_mb(100),\n            overall_speed=0,\n            overall_eta=timedelta(seconds=0),\n            status=\"not_started\",\n        )\n\n\nasync def test_re_download_after_delete_completes() -> None:\n    \"\"\"A model that was downloaded, deleted, and then re-downloaded should\n    reach DownloadCompleted status. This is an end-to-end test through\n    the DownloadCoordinator.\"\"\"\n    cmd_send: Sender[ForwarderDownloadCommand]\n    cmd_send, cmd_recv = channel[ForwarderDownloadCommand]()\n    event_send, event_recv = channel[Event]()\n\n    fake_downloader = FakeShardDownloader()\n    wrapped_downloader = SingletonShardDownloader(fake_downloader)\n    coordinator = DownloadCoordinator(\n        node_id=NODE_ID,\n        shard_downloader=wrapped_downloader,\n        download_command_receiver=cmd_recv,\n        event_sender=event_send,\n    )\n\n    shard = _make_shard()\n    origin = SystemId(\"test\")\n\n    with patch(\"exo.download.coordinator.delete_model\", new_callable=AsyncMock):\n        # Run the coordinator in the background\n        coordinator_task = asyncio.create_task(coordinator.run())\n\n        try:\n            # 1. Start first download\n            await cmd_send.send(\n                ForwarderDownloadCommand(\n                    origin=origin,\n                    command=StartDownload(target_node_id=NODE_ID, shard_metadata=shard),\n                )\n            )\n\n            # Wait for DownloadCompleted\n            first_completed = await _wait_for_download_completed(event_recv, MODEL_ID)\n            assert first_completed is not None, \"First download should complete\"\n\n            # 2. Delete the model\n            await cmd_send.send(\n                ForwarderDownloadCommand(\n                    origin=origin,\n                    command=DeleteDownload(target_node_id=NODE_ID, model_id=MODEL_ID),\n                )\n            )\n            # Give the coordinator time to process the delete\n            await asyncio.sleep(0.05)\n\n            # 3. 
Re-download the same model\n            await cmd_send.send(\n                ForwarderDownloadCommand(\n                    origin=origin,\n                    command=StartDownload(target_node_id=NODE_ID, shard_metadata=shard),\n                )\n            )\n\n            # Wait for the second DownloadCompleted; before the fix, this event never arrived\n            second_completed = await _wait_for_download_completed(event_recv, MODEL_ID)\n            assert second_completed is not None, (\n                \"Re-download after deletion should complete\"\n            )\n        finally:\n            coordinator.shutdown()\n            coordinator_task.cancel()\n            with contextlib.suppress(asyncio.CancelledError):\n                await coordinator_task\n\n\nasync def _wait_for_download_completed(\n    event_recv: Receiver[Event], model_id: ModelId, timeout: float = 2.0\n) -> DownloadCompleted | None:\n    \"\"\"Drain events until we see a DownloadCompleted for the given model, or timeout.\"\"\"\n    try:\n        async with asyncio.timeout(timeout):\n            while True:\n                event = await event_recv.receive()\n                if (\n                    isinstance(event, NodeDownloadProgress)\n                    and isinstance(event.download_progress, DownloadCompleted)\n                    and event.download_progress.shard_metadata.model_card.model_id\n                    == model_id\n                ):\n                    return event.download_progress\n    except TimeoutError:\n        return None\n"
  },
  {
    "path": "src/exo/main.py",
    "content": "import argparse\nimport multiprocessing as mp\nimport os\nimport resource\nimport signal\nfrom dataclasses import dataclass, field\nfrom typing import Self\n\nimport anyio\nfrom loguru import logger\nfrom pydantic import PositiveInt\n\nimport exo.routing.topics as topics\nfrom exo.api.main import API\nfrom exo.download.coordinator import DownloadCoordinator\nfrom exo.download.impl_shard_downloader import exo_shard_downloader\nfrom exo.master.main import Master\nfrom exo.routing.event_router import EventRouter\nfrom exo.routing.router import Router, get_node_id_keypair\nfrom exo.shared.constants import EXO_LOG\nfrom exo.shared.election import Election, ElectionResult\nfrom exo.shared.logging import logger_cleanup, logger_setup\nfrom exo.shared.types.common import NodeId, SessionId\nfrom exo.utils.channels import Receiver, channel\nfrom exo.utils.pydantic_ext import CamelCaseModel\nfrom exo.utils.task_group import TaskGroup\nfrom exo.worker.main import Worker\n\n\n@dataclass\nclass Node:\n    router: Router\n    event_router: EventRouter\n    download_coordinator: DownloadCoordinator | None\n    worker: Worker | None\n    election: Election  # Every node participates in election, as we do want a node to become master even if it isn't a master candidate if no master candidates are present.\n    election_result_receiver: Receiver[ElectionResult]\n    master: Master | None\n    api: API | None\n\n    node_id: NodeId\n    offline: bool\n    _tg: TaskGroup = field(init=False, default_factory=TaskGroup)\n\n    @classmethod\n    async def create(cls, args: \"Args\") -> Self:\n        keypair = get_node_id_keypair()\n        node_id = NodeId(keypair.to_node_id())\n        session_id = SessionId(master_node_id=node_id, election_clock=0)\n        router = Router.create(keypair)\n        await router.register_topic(topics.GLOBAL_EVENTS)\n        await router.register_topic(topics.LOCAL_EVENTS)\n        await router.register_topic(topics.COMMANDS)\n        await router.register_topic(topics.ELECTION_MESSAGES)\n        await router.register_topic(topics.CONNECTION_MESSAGES)\n        await router.register_topic(topics.DOWNLOAD_COMMANDS)\n        event_router = EventRouter(\n            session_id,\n            command_sender=router.sender(topics.COMMANDS),\n            external_outbound=router.sender(topics.LOCAL_EVENTS),\n            external_inbound=router.receiver(topics.GLOBAL_EVENTS),\n        )\n\n        logger.info(f\"Starting node {node_id}\")\n\n        # Create DownloadCoordinator (unless --no-downloads)\n        if not args.no_downloads:\n            download_coordinator = DownloadCoordinator(\n                node_id,\n                exo_shard_downloader(offline=args.offline),\n                event_sender=event_router.sender(),\n                download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),\n                offline=args.offline,\n            )\n        else:\n            download_coordinator = None\n\n        if args.spawn_api:\n            api = API(\n                node_id,\n                port=args.api_port,\n                event_receiver=event_router.receiver(),\n                command_sender=router.sender(topics.COMMANDS),\n                download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),\n                election_receiver=router.receiver(topics.ELECTION_MESSAGES),\n            )\n        else:\n            api = None\n\n        if not args.no_worker:\n            worker = Worker(\n                node_id,\n                
event_receiver=event_router.receiver(),\n                event_sender=event_router.sender(),\n                command_sender=router.sender(topics.COMMANDS),\n                download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),\n            )\n        else:\n            worker = None\n\n        # We start every node with a master\n        master = Master(\n            node_id,\n            session_id,\n            event_sender=event_router.sender(),\n            global_event_sender=router.sender(topics.GLOBAL_EVENTS),\n            local_event_receiver=router.receiver(topics.LOCAL_EVENTS),\n            command_receiver=router.receiver(topics.COMMANDS),\n            download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),\n        )\n\n        er_send, er_recv = channel[ElectionResult]()\n        election = Election(\n            node_id,\n            # If someone manages to assemble 1 MILLION devices into an exo cluster then: well done, good job champ.\n            seniority=1_000_000 if args.force_master else 0,\n            # nb: this DOES feed back right now. i have thoughts on how to address this,\n            # but ultimately it seems not worth the complexity\n            election_message_sender=router.sender(topics.ELECTION_MESSAGES),\n            election_message_receiver=router.receiver(topics.ELECTION_MESSAGES),\n            connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),\n            command_receiver=router.receiver(topics.COMMANDS),\n            election_result_sender=er_send,\n        )\n\n        return cls(\n            router,\n            event_router,\n            download_coordinator,\n            worker,\n            election,\n            er_recv,\n            master,\n            api,\n            node_id,\n            args.offline,\n        )\n\n    async def run(self):\n        async with self._tg as tg:\n            signal.signal(signal.SIGINT, lambda _, __: self.shutdown())\n            signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())\n            tg.start_soon(self.router.run)\n            tg.start_soon(self.event_router.run)\n            tg.start_soon(self.election.run)\n            if self.download_coordinator:\n                tg.start_soon(self.download_coordinator.run)\n            if self.worker:\n                tg.start_soon(self.worker.run)\n            if self.master:\n                tg.start_soon(self.master.run)\n            if self.api:\n                tg.start_soon(self.api.run)\n            tg.start_soon(self._elect_loop)\n\n    def shutdown(self):\n        # if this is our second call to shutdown, just sys.exit\n        if self._tg.cancel_called():\n            import sys\n\n            sys.exit(1)\n        self._tg.cancel_tasks()\n\n    async def _elect_loop(self):\n        with self.election_result_receiver as results:\n            async for result in results:\n                # This function continues to have a lot of very specific entangled logic\n                # At least it's somewhat contained\n\n                # I don't like this duplication, but it's manageable for now.\n                # TODO: This function needs refactoring generally\n\n                # Ok:\n                # On new master:\n                # - Elect master locally if necessary\n                # - Shut down and re-create the worker\n                # - Shut down and re-create the API\n\n                if result.is_new_master:\n                    await anyio.sleep(0)\n                    
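# A new master means a new session id: rebuild the event router so its events carry the new session and are not discarded by session filtering.\n                    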
self.event_router.shutdown()\n                    self.event_router = EventRouter(\n                        result.session_id,\n                        command_sender=self.router.sender(topics.COMMANDS),\n                        external_outbound=self.router.sender(topics.LOCAL_EVENTS),\n                        external_inbound=self.router.receiver(topics.GLOBAL_EVENTS),\n                    )\n                    self._tg.start_soon(self.event_router.run)\n\n                if (\n                    result.session_id.master_node_id == self.node_id\n                    and self.master is not None\n                ):\n                    logger.info(\"Node elected Master\")\n                elif (\n                    result.session_id.master_node_id == self.node_id\n                    and self.master is None\n                ):\n                    logger.info(\"Node elected Master - promoting self\")\n                    self.master = Master(\n                        self.node_id,\n                        result.session_id,\n                        event_sender=self.event_router.sender(),\n                        global_event_sender=self.router.sender(topics.GLOBAL_EVENTS),\n                        local_event_receiver=self.router.receiver(topics.LOCAL_EVENTS),\n                        command_receiver=self.router.receiver(topics.COMMANDS),\n                        download_command_sender=self.router.sender(\n                            topics.DOWNLOAD_COMMANDS\n                        ),\n                    )\n                    self._tg.start_soon(self.master.run)\n                elif (\n                    result.session_id.master_node_id != self.node_id\n                    and self.master is not None\n                ):\n                    logger.info(\n                        f\"Node {result.session_id.master_node_id} elected master - demoting self\"\n                    )\n                    await self.master.shutdown()\n                    self.master = None\n                else:\n                    logger.info(\n                        f\"Node {result.session_id.master_node_id} elected master\"\n                    )\n                if result.is_new_master:\n                    if self.download_coordinator:\n                        self.download_coordinator.shutdown()\n                        self.download_coordinator = DownloadCoordinator(\n                            self.node_id,\n                            exo_shard_downloader(offline=self.offline),\n                            event_sender=self.event_router.sender(),\n                            download_command_receiver=self.router.receiver(\n                                topics.DOWNLOAD_COMMANDS\n                            ),\n                            offline=self.offline,\n                        )\n                        self._tg.start_soon(self.download_coordinator.run)\n                    if self.worker:\n                        self.worker.shutdown()\n                        # TODO: add profiling etc to resource monitor\n                        self.worker = Worker(\n                            self.node_id,\n                            event_receiver=self.event_router.receiver(),\n                            event_sender=self.event_router.sender(),\n                            command_sender=self.router.sender(topics.COMMANDS),\n                            download_command_sender=self.router.sender(\n                                topics.DOWNLOAD_COMMANDS\n                            ),\n                        )\n                        
self._tg.start_soon(self.worker.run)\n                    if self.api:\n                        self.api.reset(result.won_clock, self.event_router.receiver())\n                else:\n                    if self.api:\n                        self.api.unpause(result.won_clock)\n\n\ndef main():\n    args = Args.parse()\n    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)\n    target = min(max(soft, 65535), hard)\n    resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))\n\n    mp.set_start_method(\"spawn\", force=True)\n    # TODO: Refactor the current verbosity system\n    logger_setup(EXO_LOG, args.verbosity)\n    logger.info(\"Starting EXO\")\n    logger.info(f\"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}\")\n\n    if args.offline:\n        logger.info(\"Running in OFFLINE mode — no internet checks, local models only\")\n\n    if args.no_batch:\n        os.environ[\"EXO_NO_BATCH\"] = \"1\"\n        logger.info(\"Continuous batching disabled (--no-batch)\")\n\n    # Set FAST_SYNCH override env var for runner subprocesses\n    if args.fast_synch is True:\n        os.environ[\"EXO_FAST_SYNCH\"] = \"on\"\n        logger.info(\"FAST_SYNCH forced ON\")\n    elif args.fast_synch is False:\n        os.environ[\"EXO_FAST_SYNCH\"] = \"off\"\n        logger.info(\"FAST_SYNCH forced OFF\")\n\n    node = anyio.run(Node.create, args)\n    try:\n        anyio.run(node.run)\n    except BaseException as exception:\n        logger.opt(exception=exception).critical(\n            \"EXO terminated due to unhandled exception\"\n        )\n        raise\n    finally:\n        logger.info(\"EXO Shutdown complete\")\n        logger_cleanup()\n\n\nclass Args(CamelCaseModel):\n    verbosity: int = 0\n    force_master: bool = False\n    spawn_api: bool = False\n    api_port: PositiveInt = 52415\n    tb_only: bool = False\n    no_worker: bool = False\n    no_downloads: bool = False\n    offline: bool = os.getenv(\"EXO_OFFLINE\", \"false\").lower() == \"true\"\n    no_batch: bool = False\n    fast_synch: bool | None = None  # None = auto, True = force on, False = force off\n\n    @classmethod\n    def parse(cls) -> Self:\n        parser = argparse.ArgumentParser(prog=\"EXO\")\n        default_verbosity = 0\n        parser.add_argument(\n            \"-q\",\n            \"--quiet\",\n            action=\"store_const\",\n            const=-1,\n            dest=\"verbosity\",\n            default=default_verbosity,\n        )\n        parser.add_argument(\n            \"-v\",\n            \"--verbose\",\n            action=\"count\",\n            dest=\"verbosity\",\n            default=default_verbosity,\n        )\n        parser.add_argument(\n            \"-m\",\n            \"--force-master\",\n            action=\"store_true\",\n            dest=\"force_master\",\n        )\n        parser.add_argument(\n            \"--no-api\",\n            action=\"store_false\",\n            dest=\"spawn_api\",\n        )\n        parser.add_argument(\n            \"--api-port\",\n            type=int,\n            dest=\"api_port\",\n            default=52415,\n        )\n        parser.add_argument(\n            \"--no-worker\",\n            action=\"store_true\",\n        )\n        parser.add_argument(\n            \"--no-downloads\",\n            action=\"store_true\",\n            help=\"Disable the download coordinator (node won't download models)\",\n        )\n        parser.add_argument(\n            \"--offline\",\n            action=\"store_true\",\n            
default=os.getenv(\"EXO_OFFLINE\", \"false\").lower() == \"true\",\n            help=\"Run in offline/air-gapped mode: skip internet checks, use only pre-staged local models\",\n        )\n        parser.add_argument(\n            \"--no-batch\",\n            action=\"store_true\",\n            help=\"Disable continuous batching, use sequential generation\",\n        )\n        fast_synch_group = parser.add_mutually_exclusive_group()\n        fast_synch_group.add_argument(\n            \"--fast-synch\",\n            action=\"store_true\",\n            dest=\"fast_synch\",\n            default=None,\n            help=\"Force MLX FAST_SYNCH on (for JACCL backend)\",\n        )\n        fast_synch_group.add_argument(\n            \"--no-fast-synch\",\n            action=\"store_false\",\n            dest=\"fast_synch\",\n            help=\"Force MLX FAST_SYNCH off\",\n        )\n\n        args = parser.parse_args()\n        return cls(**vars(args))  # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically\n"
  },
  {
    "path": "src/exo/master/__init__.py",
    "content": ""
  },
  {
    "path": "src/exo/master/image_store.py",
    "content": "import time\nfrom pathlib import Path\n\nfrom pydantic import BaseModel\n\nfrom exo.shared.types.common import Id\n\n\nclass StoredImage(BaseModel, frozen=True):\n    image_id: Id\n    file_path: Path\n    content_type: str\n    expires_at: float\n\n\nclass ImageStore:\n    def __init__(self, storage_dir: Path, default_expiry_seconds: int = 3600) -> None:\n        self._storage_dir = storage_dir\n        self._default_expiry_seconds = default_expiry_seconds\n        self._images: dict[Id, StoredImage] = {}\n        self._storage_dir.mkdir(parents=True, exist_ok=True)\n\n    def store(self, image_bytes: bytes, content_type: str) -> StoredImage:\n        image_id = Id()\n        extension = _content_type_to_extension(content_type)\n        file_path = self._storage_dir / f\"{image_id}{extension}\"\n        file_path.write_bytes(image_bytes)\n\n        stored = StoredImage(\n            image_id=image_id,\n            file_path=file_path,\n            content_type=content_type,\n            expires_at=time.time() + self._default_expiry_seconds,\n        )\n        self._images[image_id] = stored\n        return stored\n\n    def get(self, image_id: Id) -> StoredImage | None:\n        stored = self._images.get(image_id)\n        if stored is None:\n            return None\n\n        if time.time() > stored.expires_at:\n            self._remove(image_id)\n            return None\n\n        return stored\n\n    def list_images(self) -> list[StoredImage]:\n        now = time.time()\n        return [stored for stored in self._images.values() if now <= stored.expires_at]\n\n    def cleanup_expired(self) -> int:\n        now = time.time()\n        expired_ids = [\n            image_id\n            for image_id, stored in self._images.items()\n            if now > stored.expires_at\n        ]\n\n        for image_id in expired_ids:\n            self._remove(image_id)\n\n        return len(expired_ids)\n\n    def _remove(self, image_id: Id) -> None:\n        stored = self._images.pop(image_id, None)\n        if stored is not None and stored.file_path.exists():\n            stored.file_path.unlink()\n\n\ndef _content_type_to_extension(\n    content_type: str,\n) -> str:\n    ext = f\"{content_type.split('/')[1]}\"\n    if ext == \"jpeg\":\n        ext = \"jpg\"\n\n    return f\".{ext}\"\n"
  },
  {
    "path": "src/exo/master/main.py",
    "content": "from datetime import datetime, timedelta, timezone\n\nimport anyio\nfrom loguru import logger\n\nfrom exo.master.placement import (\n    add_instance_to_placements,\n    cancel_unnecessary_downloads,\n    delete_instance,\n    get_transition_events,\n    place_instance,\n)\nfrom exo.shared.apply import apply\nfrom exo.shared.constants import EXO_EVENT_LOG_DIR, EXO_TRACING_ENABLED\nfrom exo.shared.types.commands import (\n    CreateInstance,\n    DeleteInstance,\n    ForwarderCommand,\n    ForwarderDownloadCommand,\n    ImageEdits,\n    ImageGeneration,\n    PlaceInstance,\n    RequestEventLog,\n    SendInputChunk,\n    TaskCancelled,\n    TaskFinished,\n    TestCommand,\n    TextGeneration,\n)\nfrom exo.shared.types.common import CommandId, NodeId, SessionId, SystemId\nfrom exo.shared.types.events import (\n    Event,\n    GlobalForwarderEvent,\n    IndexedEvent,\n    InputChunkReceived,\n    InstanceDeleted,\n    LocalForwarderEvent,\n    NodeGatheredInfo,\n    NodeTimedOut,\n    TaskCreated,\n    TaskDeleted,\n    TaskStatusUpdated,\n    TraceEventData,\n    TracesCollected,\n    TracesMerged,\n)\nfrom exo.shared.types.state import State\nfrom exo.shared.types.tasks import (\n    ImageEdits as ImageEditsTask,\n)\nfrom exo.shared.types.tasks import (\n    ImageGeneration as ImageGenerationTask,\n)\nfrom exo.shared.types.tasks import (\n    TaskId,\n    TaskStatus,\n)\nfrom exo.shared.types.tasks import (\n    TextGeneration as TextGenerationTask,\n)\nfrom exo.shared.types.worker.instances import InstanceId\nfrom exo.utils.channels import Receiver, Sender\nfrom exo.utils.disk_event_log import DiskEventLog\nfrom exo.utils.event_buffer import MultiSourceBuffer\nfrom exo.utils.task_group import TaskGroup\n\n\nclass Master:\n    def __init__(\n        self,\n        node_id: NodeId,\n        session_id: SessionId,\n        *,\n        command_receiver: Receiver[ForwarderCommand],\n        event_sender: Sender[Event],\n        local_event_receiver: Receiver[LocalForwarderEvent],\n        global_event_sender: Sender[GlobalForwarderEvent],\n        download_command_sender: Sender[ForwarderDownloadCommand],\n    ):\n        self.node_id = node_id\n        self.session_id = session_id\n        self.state = State()\n        self._tg: TaskGroup = TaskGroup()\n        self.command_task_mapping: dict[CommandId, TaskId] = {}\n        self.command_receiver = command_receiver\n        self.local_event_receiver = local_event_receiver\n        self.global_event_sender = global_event_sender\n        self.download_command_sender = download_command_sender\n        self.event_sender = event_sender\n        self._system_id = SystemId()\n        self._multi_buffer = MultiSourceBuffer[SystemId, Event]()\n        self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / \"master\")\n        self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}\n        self._expected_ranks: dict[TaskId, set[int]] = {}\n\n    async def run(self):\n        logger.info(\"Starting Master\")\n\n        try:\n            async with self._tg as tg:\n                tg.start_soon(self._event_processor)\n                tg.start_soon(self._command_processor)\n                tg.start_soon(self._plan)\n        finally:\n            self._event_log.close()\n            self.global_event_sender.close()\n            self.local_event_receiver.close()\n            self.command_receiver.close()\n\n    async def shutdown(self):\n        logger.info(\"Stopping Master\")\n        self._tg.cancel_tasks()\n\n    async def 
_command_processor(self) -> None:\n        with self.command_receiver as commands:\n            async for forwarder_command in commands:\n                try:\n                    logger.info(f\"Executing command: {forwarder_command.command}\")\n\n                    generated_events: list[Event] = []\n                    command = forwarder_command.command\n                    instance_task_counts: dict[InstanceId, int] = {}\n                    match command:\n                        case TestCommand():\n                            pass\n                        case TextGeneration():\n                            for instance in self.state.instances.values():\n                                if (\n                                    instance.shard_assignments.model_id\n                                    == command.task_params.model\n                                ):\n                                    task_count = sum(\n                                        1\n                                        for task in self.state.tasks.values()\n                                        if task.instance_id == instance.instance_id\n                                    )\n                                    instance_task_counts[instance.instance_id] = (\n                                        task_count\n                                    )\n\n                            if not instance_task_counts:\n                                raise ValueError(\n                                    f\"No instance found for model {command.task_params.model}\"\n                                )\n\n                            available_instance_ids = sorted(\n                                instance_task_counts.keys(),\n                                key=lambda instance_id: instance_task_counts[\n                                    instance_id\n                                ],\n                            )\n\n                            task_id = TaskId()\n                            generated_events.append(\n                                TaskCreated(\n                                    task_id=task_id,\n                                    task=TextGenerationTask(\n                                        task_id=task_id,\n                                        command_id=command.command_id,\n                                        instance_id=available_instance_ids[0],\n                                        task_status=TaskStatus.Pending,\n                                        task_params=command.task_params,\n                                    ),\n                                )\n                            )\n\n                            self.command_task_mapping[command.command_id] = task_id\n                        case ImageGeneration():\n                            for instance in self.state.instances.values():\n                                if (\n                                    instance.shard_assignments.model_id\n                                    == command.task_params.model\n                                ):\n                                    task_count = sum(\n                                        1\n                                        for task in self.state.tasks.values()\n                                        if task.instance_id == instance.instance_id\n                                    )\n                                    instance_task_counts[instance.instance_id] = (\n                                        task_count\n                                    
)\n\n                            if not instance_task_counts:\n                                raise ValueError(\n                                    f\"No instance found for model {command.task_params.model}\"\n                                )\n\n                            available_instance_ids = sorted(\n                                instance_task_counts.keys(),\n                                key=lambda instance_id: instance_task_counts[\n                                    instance_id\n                                ],\n                            )\n\n                            task_id = TaskId()\n                            selected_instance_id = available_instance_ids[0]\n                            generated_events.append(\n                                TaskCreated(\n                                    task_id=task_id,\n                                    task=ImageGenerationTask(\n                                        task_id=task_id,\n                                        command_id=command.command_id,\n                                        instance_id=selected_instance_id,\n                                        task_status=TaskStatus.Pending,\n                                        task_params=command.task_params,\n                                    ),\n                                )\n                            )\n\n                            self.command_task_mapping[command.command_id] = task_id\n\n                            if EXO_TRACING_ENABLED:\n                                selected_instance = self.state.instances.get(\n                                    selected_instance_id\n                                )\n                                if selected_instance:\n                                    ranks = set(\n                                        shard.device_rank\n                                        for shard in selected_instance.shard_assignments.runner_to_shard.values()\n                                    )\n                                    self._expected_ranks[task_id] = ranks\n                        case ImageEdits():\n                            for instance in self.state.instances.values():\n                                if (\n                                    instance.shard_assignments.model_id\n                                    == command.task_params.model\n                                ):\n                                    task_count = sum(\n                                        1\n                                        for task in self.state.tasks.values()\n                                        if task.instance_id == instance.instance_id\n                                    )\n                                    instance_task_counts[instance.instance_id] = (\n                                        task_count\n                                    )\n\n                            if not instance_task_counts:\n                                raise ValueError(\n                                    f\"No instance found for model {command.task_params.model}\"\n                                )\n\n                            available_instance_ids = sorted(\n                                instance_task_counts.keys(),\n                                key=lambda instance_id: instance_task_counts[\n                                    instance_id\n                                ],\n                            )\n\n                            task_id = TaskId()\n                            selected_instance_id = 
available_instance_ids[0]\n                            generated_events.append(\n                                TaskCreated(\n                                    task_id=task_id,\n                                    task=ImageEditsTask(\n                                        task_id=task_id,\n                                        command_id=command.command_id,\n                                        instance_id=selected_instance_id,\n                                        task_status=TaskStatus.Pending,\n                                        task_params=command.task_params,\n                                    ),\n                                )\n                            )\n\n                            self.command_task_mapping[command.command_id] = task_id\n\n                            if EXO_TRACING_ENABLED:\n                                selected_instance = self.state.instances.get(\n                                    selected_instance_id\n                                )\n                                if selected_instance:\n                                    ranks = set(\n                                        shard.device_rank\n                                        for shard in selected_instance.shard_assignments.runner_to_shard.values()\n                                    )\n                                    self._expected_ranks[task_id] = ranks\n                        case DeleteInstance():\n                            placement = delete_instance(command, self.state.instances)\n                            transition_events = get_transition_events(\n                                self.state.instances, placement, self.state.tasks\n                            )\n                            for cmd in cancel_unnecessary_downloads(\n                                placement, self.state.downloads\n                            ):\n                                await self.download_command_sender.send(\n                                    ForwarderDownloadCommand(\n                                        origin=self._system_id, command=cmd\n                                    )\n                                )\n                            generated_events.extend(transition_events)\n                        case PlaceInstance():\n                            placement = place_instance(\n                                command,\n                                self.state.topology,\n                                self.state.instances,\n                                self.state.node_memory,\n                                self.state.node_network,\n                            )\n                            transition_events = get_transition_events(\n                                self.state.instances, placement, self.state.tasks\n                            )\n                            generated_events.extend(transition_events)\n                        case CreateInstance():\n                            placement = add_instance_to_placements(\n                                command,\n                                self.state.topology,\n                                self.state.instances,\n                            )\n                            transition_events = get_transition_events(\n                                self.state.instances, placement, self.state.tasks\n                            )\n                            generated_events.extend(transition_events)\n                        case SendInputChunk(chunk=chunk):\n                           
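 # Wrap the chunk in an InputChunkReceived event so input data flows through the ordered event stream like any other state change.\n                           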
 generated_events.append(\n                                InputChunkReceived(\n                                    command_id=chunk.command_id,\n                                    chunk=chunk,\n                                )\n                            )\n                        case TaskCancelled():\n                            if (\n                                task_id := self.command_task_mapping.get(\n                                    command.cancelled_command_id\n                                )\n                            ) is not None:\n                                generated_events.append(\n                                    TaskStatusUpdated(\n                                        task_status=TaskStatus.Cancelled,\n                                        task_id=task_id,\n                                    )\n                                )\n                            else:\n                                logger.warning(\n                                    f\"Nonexistent command {command.cancelled_command_id} cancelled\"\n                                )\n                        case TaskFinished():\n                            if (\n                                task_id := self.command_task_mapping.pop(\n                                    command.finished_command_id, None\n                                )\n                            ) is not None:\n                                generated_events.append(TaskDeleted(task_id=task_id))\n                            else:\n                                logger.warning(\n                                    f\"Nonexistent command {command.finished_command_id} finished\"\n                                )\n\n                        case RequestEventLog():\n                            # Resending everything is safe: other buffers ignore events they have already seen.\n                            # Rate-limit the replay to 1000 events per request.\n                            end = min(command.since_idx + 1000, len(self._event_log))\n                            for i, event in enumerate(\n                                self._event_log.read_range(command.since_idx, end),\n                                start=command.since_idx,\n                            ):\n                                await self._send_event(IndexedEvent(idx=i, event=event))\n                    for event in generated_events:\n                        await self.event_sender.send(event)\n                except ValueError as e:\n                    logger.opt(exception=e).warning(\"Error in command processor\")\n\n    # These plan loops are the cracks showing in our event-sourcing architecture; more of this could be modelled as commands\n    async def _plan(self) -> None:\n        while True:\n            # kill broken instances\n            connected_node_ids = set(self.state.topology.list_nodes())\n            for instance_id, instance in self.state.instances.items():\n                for node_id in instance.shard_assignments.node_to_runner:\n                    if node_id not in connected_node_ids:\n                        await self.event_sender.send(\n                            InstanceDeleted(instance_id=instance_id)\n                        )\n                        break\n\n            # time out dead nodes\n            now = datetime.now(tz=timezone.utc)\n            for node_id, last_seen_at in self.state.last_seen.items():\n                if now - last_seen_at > timedelta(seconds=30):\n                    logger.info(f\"Manually removing node {node_id} due to inactivity\")\n                    await self.event_sender.send(NodeTimedOut(node_id=node_id))\n\n            await anyio.sleep(10)\n\n    async def _event_processor(self) -> None:\n        with self.local_event_receiver as local_events:\n            async for local_event in local_events:\n                # Discard all events not from our session\n                if local_event.session != self.session_id:\n                    continue\n                self._multi_buffer.ingest(\n                    local_event.origin_idx,\n                    local_event.event,\n                    local_event.origin,\n                )\n                for event in self._multi_buffer.drain():\n                    if isinstance(event, TracesCollected):\n                        await self._handle_traces_collected(event)\n                        continue\n\n                    logger.debug(f\"Master indexing event: {str(event)[:100]}\")\n                    indexed = IndexedEvent(event=event, idx=len(self._event_log))\n                    self.state = apply(self.state, indexed)\n\n                    event._master_time_stamp = datetime.now(tz=timezone.utc)  # pyright: ignore[reportPrivateUsage]\n                    if isinstance(event, NodeGatheredInfo):\n                        event.when = str(datetime.now(tz=timezone.utc))\n\n                    self._event_log.append(event)\n                    await self._send_event(indexed)\n\n    # This function is re-entrant, take care!\n    async def _send_event(self, event: IndexedEvent) -> None:\n        # Convenience wrapper: forwards an already-indexed event to the global stream\n        await self.global_event_sender.send(\n            GlobalForwarderEvent(\n                origin=self.node_id,\n                origin_idx=event.idx,\n                session=self.session_id,\n                event=event.event,\n            )\n        )\n\n    async def _handle_traces_collected(self, event: TracesCollected) -> None:\n        task_id = event.task_id\n        self._pending_traces.setdefault(task_id, {})[event.rank] = event.traces\n\n        if (\n            task_id in self._expected_ranks\n            and set(self._pending_traces[task_id].keys())\n            >= self._expected_ranks[task_id]\n        ):\n            await self._merge_and_save_traces(task_id)\n\n    async def _merge_and_save_traces(self, task_id: TaskId) -> None:\n        all_trace_data: list[TraceEventData] = []\n        for trace_data in self._pending_traces[task_id].values():\n            all_trace_data.extend(trace_data)\n\n        await self.event_sender.send(\n            TracesMerged(task_id=task_id, traces=all_trace_data)\n        )\n\n        del self._pending_traces[task_id]\n        self._expected_ranks.pop(task_id, None)\n"
  },
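  {
    "path": "examples/editor_sketch_event_log_catchup.py",
    "content": "\"\"\"Illustrative sketch added by the editor; not part of the original exo tree.\n\nShows the catch-up contract behind the master's RequestEventLog handler: the master\nreplays at most 1000 indexed events per request, and a follower keeps asking from its\nnext index until it is caught up. EventLogStub, IndexedEventStub and the helper names\nare hypothetical stand-ins; only the pagination arithmetic mirrors the handler in\nsrc/exo/master/main.py.\n\"\"\"\n\nfrom dataclasses import dataclass\n\nPAGE_SIZE = 1000  # mirrors the handler's rate limit\n\n\n@dataclass(frozen=True)\nclass IndexedEventStub:\n    idx: int\n    event: str\n\n\nclass EventLogStub:\n    def __init__(self, events: list[str]) -> None:\n        self._events = events\n\n    def __len__(self) -> int:\n        return len(self._events)\n\n    def read_range(self, start: int, end: int) -> list[str]:\n        return self._events[start:end]\n\n\ndef replay_since(log: EventLogStub, since_idx: int) -> list[IndexedEventStub]:\n    # One RequestEventLog round trip: at most PAGE_SIZE events, indexed from since_idx.\n    end = min(since_idx + PAGE_SIZE, len(log))\n    return [\n        IndexedEventStub(idx=i, event=e)\n        for i, e in enumerate(log.read_range(since_idx, end), start=since_idx)\n    ]\n\n\ndef catch_up(log: EventLogStub, since_idx: int = 0) -> int:\n    # A follower keeps requesting until a page comes back short.\n    cursor = since_idx\n    while True:\n        page = replay_since(log, cursor)\n        cursor += len(page)\n        if len(page) < PAGE_SIZE:\n            return cursor\n\n\nif __name__ == \"__main__\":\n    log = EventLogStub([f\"event-{i}\" for i in range(2500)])\n    assert catch_up(log) == 2500\n"
  },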
  {
    "path": "src/exo/master/placement.py",
    "content": "import random\nfrom collections.abc import Mapping\nfrom copy import deepcopy\nfrom typing import Sequence\n\nfrom exo.master.placement_utils import (\n    Cycle,\n    filter_cycles_by_memory,\n    get_mlx_jaccl_coordinators,\n    get_mlx_jaccl_devices_matrix,\n    get_mlx_ring_hosts_by_node,\n    get_shard_assignments,\n    get_smallest_cycles,\n)\nfrom exo.shared.models.model_cards import ModelId\nfrom exo.shared.topology import Topology\nfrom exo.shared.types.commands import (\n    CancelDownload,\n    CreateInstance,\n    DeleteInstance,\n    DownloadCommand,\n    PlaceInstance,\n)\nfrom exo.shared.types.common import NodeId\nfrom exo.shared.types.events import (\n    Event,\n    InstanceCreated,\n    InstanceDeleted,\n    TaskStatusUpdated,\n)\nfrom exo.shared.types.memory import Memory\nfrom exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo\nfrom exo.shared.types.tasks import Task, TaskId, TaskStatus\nfrom exo.shared.types.worker.downloads import (\n    DownloadOngoing,\n    DownloadProgress,\n)\nfrom exo.shared.types.worker.instances import (\n    Instance,\n    InstanceId,\n    InstanceMeta,\n    MlxJacclInstance,\n    MlxRingInstance,\n)\nfrom exo.shared.types.worker.shards import Sharding\n\n\ndef random_ephemeral_port() -> int:\n    port = random.randint(49153, 65535)\n    return port - 1 if port <= 52415 else port\n\n\ndef add_instance_to_placements(\n    command: CreateInstance,\n    topology: Topology,\n    current_instances: Mapping[InstanceId, Instance],\n) -> Mapping[InstanceId, Instance]:\n    # TODO: validate against topology\n\n    return {**current_instances, command.instance.instance_id: command.instance}\n\n\ndef place_instance(\n    command: PlaceInstance,\n    topology: Topology,\n    current_instances: Mapping[InstanceId, Instance],\n    node_memory: Mapping[NodeId, MemoryUsage],\n    node_network: Mapping[NodeId, NodeNetworkInfo],\n    required_nodes: set[NodeId] | None = None,\n) -> dict[InstanceId, Instance]:\n    cycles = topology.get_cycles()\n    candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))\n\n    # Filter to cycles containing all required nodes (subset matching)\n    if required_nodes:\n        candidate_cycles = [\n            cycle\n            for cycle in candidate_cycles\n            if required_nodes.issubset(cycle.node_ids)\n        ]\n    cycles_with_sufficient_memory = filter_cycles_by_memory(\n        candidate_cycles, node_memory, command.model_card.storage_size\n    )\n    if len(cycles_with_sufficient_memory) == 0:\n        raise ValueError(\"No cycles found with sufficient memory\")\n\n    if command.sharding == Sharding.Tensor:\n        if not command.model_card.supports_tensor:\n            raise ValueError(\n                f\"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_card.model_id}\"\n            )\n        # TODO: the condition here for tensor parallel is not correct, but it works good enough for now.\n        kv_heads = command.model_card.num_key_value_heads\n        cycles_with_sufficient_memory = [\n            cycle\n            for cycle in cycles_with_sufficient_memory\n            if command.model_card.hidden_size % len(cycle) == 0\n            and (kv_heads is None or kv_heads % len(cycle) == 0)\n        ]\n        if not cycles_with_sufficient_memory:\n            raise ValueError(\n                f\"No tensor sharding found for model with \"\n                f\"hidden_size={command.model_card.hidden_size}\"\n      
          f\"{f', num_key_value_heads={kv_heads}' if kv_heads is not None else ''}\"\n                f\" across candidate cycles\"\n            )\n    if command.sharding == Sharding.Pipeline and command.model_card.model_id == ModelId(\n        \"mlx-community/DeepSeek-V3.1-8bit\"\n    ):\n        raise ValueError(\n            \"Pipeline parallelism is not supported for DeepSeek V3.1 (8-bit)\"\n        )\n\n    smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)\n\n    smallest_rdma_cycles = [\n        cycle for cycle in smallest_cycles if topology.is_rdma_cycle(cycle)\n    ]\n\n    if command.instance_meta == InstanceMeta.MlxJaccl:\n        if not smallest_rdma_cycles:\n            raise ValueError(\n                \"Requested RDMA (MlxJaccl) but no RDMA-connected cycles available\"\n            )\n        smallest_cycles = smallest_rdma_cycles\n\n    cycles_with_leaf_nodes: list[Cycle] = [\n        cycle\n        for cycle in smallest_cycles\n        if any(topology.node_is_leaf(node_id) for node_id in cycle)\n    ]\n\n    selected_cycle = max(\n        cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else smallest_cycles,\n        key=lambda cycle: sum(\n            (node_memory[node_id].ram_available for node_id in cycle),\n            start=Memory(),\n        ),\n    )\n\n    # Single-node: force Pipeline/Ring (Tensor and Jaccl require multi-node)\n    if len(selected_cycle) == 1:\n        command.instance_meta = InstanceMeta.MlxRing\n        command.sharding = Sharding.Pipeline\n\n    shard_assignments = get_shard_assignments(\n        command.model_card, selected_cycle, command.sharding, node_memory\n    )\n\n    cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle.node_ids)\n\n    instance_id = InstanceId()\n    target_instances = dict(deepcopy(current_instances))\n\n    match command.instance_meta:\n        case InstanceMeta.MlxJaccl:\n            # TODO(evan): shard assignments should contain information about ranks, this is ugly\n            def get_device_rank(node_id: NodeId) -> int:\n                runner_id = shard_assignments.node_to_runner[node_id]\n                shard_metadata = shard_assignments.runner_to_shard.get(runner_id)\n                assert shard_metadata is not None\n                return shard_metadata.device_rank\n\n            zero_node_ids = [\n                node_id\n                for node_id in selected_cycle.node_ids\n                if get_device_rank(node_id) == 0\n            ]\n            assert len(zero_node_ids) == 1\n            coordinator_node_id = zero_node_ids[0]\n\n            mlx_jaccl_devices = get_mlx_jaccl_devices_matrix(\n                [node_id for node_id in selected_cycle],\n                cycle_digraph,\n            )\n            mlx_jaccl_coordinators = get_mlx_jaccl_coordinators(\n                coordinator=coordinator_node_id,\n                coordinator_port=random_ephemeral_port(),\n                cycle_digraph=cycle_digraph,\n                node_network=node_network,\n            )\n            target_instances[instance_id] = MlxJacclInstance(\n                instance_id=instance_id,\n                shard_assignments=shard_assignments,\n                jaccl_devices=mlx_jaccl_devices,\n                jaccl_coordinators=mlx_jaccl_coordinators,\n            )\n        case InstanceMeta.MlxRing:\n            ephemeral_port = random_ephemeral_port()\n            hosts_by_node = get_mlx_ring_hosts_by_node(\n                selected_cycle=selected_cycle,\n               
 cycle_digraph=cycle_digraph,\n                ephemeral_port=ephemeral_port,\n                node_network=node_network,\n            )\n            target_instances[instance_id] = MlxRingInstance(\n                instance_id=instance_id,\n                shard_assignments=shard_assignments,\n                hosts_by_node=hosts_by_node,\n                ephemeral_port=ephemeral_port,\n            )\n\n    return target_instances\n\n\ndef delete_instance(\n    command: DeleteInstance,\n    current_instances: Mapping[InstanceId, Instance],\n) -> dict[InstanceId, Instance]:\n    target_instances = dict(deepcopy(current_instances))\n    if command.instance_id in target_instances:\n        del target_instances[command.instance_id]\n        return target_instances\n    raise ValueError(f\"Instance {command.instance_id} not found\")\n\n\ndef get_transition_events(\n    current_instances: Mapping[InstanceId, Instance],\n    target_instances: Mapping[InstanceId, Instance],\n    tasks: Mapping[TaskId, Task],\n) -> Sequence[Event]:\n    events: list[Event] = []\n\n    # find instances to create\n    for instance_id, instance in target_instances.items():\n        if instance_id not in current_instances:\n            events.append(\n                InstanceCreated(\n                    instance=instance,\n                )\n            )\n\n    # find instances to delete\n    for instance_id in current_instances:\n        if instance_id not in target_instances:\n            for task in tasks.values():\n                if task.instance_id == instance_id and task.task_status in [\n                    TaskStatus.Pending,\n                    TaskStatus.Running,\n                ]:\n                    events.append(\n                        TaskStatusUpdated(\n                            task_status=TaskStatus.Cancelled,\n                            task_id=task.task_id,\n                        )\n                    )\n\n            events.append(\n                InstanceDeleted(\n                    instance_id=instance_id,\n                )\n            )\n\n    return events\n\n\ndef cancel_unnecessary_downloads(\n    instances: Mapping[InstanceId, Instance],\n    download_status: Mapping[NodeId, Sequence[DownloadProgress]],\n) -> Sequence[DownloadCommand]:\n    commands: list[DownloadCommand] = []\n    currently_downloading = [\n        (k, v.shard_metadata.model_card.model_id)\n        for k, vs in download_status.items()\n        for v in vs\n        if isinstance(v, (DownloadOngoing))\n    ]\n    active_models = set(\n        (\n            node_id,\n            instance.shard_assignments.runner_to_shard[runner_id].model_card.model_id,\n        )\n        for instance in instances.values()\n        for node_id, runner_id in instance.shard_assignments.node_to_runner.items()\n    )\n    for pair in currently_downloading:\n        if pair not in active_models:\n            commands.append(CancelDownload(target_node_id=pair[0], model_id=pair[1]))\n\n    return commands\n"
  },
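  {
    "path": "examples/editor_sketch_place_instance.py",
    "content": "\"\"\"Illustrative sketch added by the editor; not part of the original exo tree.\n\nDrives place_instance for a fully connected three-node ring, mirroring the setup in\nsrc/exo/master/tests/test_placement.py. Layer counts follow the largest-remainder\nsplit over available memory, so the 500/500/1000 profile below yields a 3/3/6 split\nof 12 layers. The file path is hypothetical; the code runs only inside this repo.\n\"\"\"\n\nfrom exo.master.placement import place_instance\nfrom exo.master.tests.conftest import (\n    create_node_memory,\n    create_node_network,\n    create_socket_connection,\n)\nfrom exo.shared.models.model_cards import ModelCard, ModelId, ModelTask\nfrom exo.shared.topology import Topology\nfrom exo.shared.types.commands import PlaceInstance\nfrom exo.shared.types.common import CommandId, NodeId\nfrom exo.shared.types.memory import Memory\nfrom exo.shared.types.topology import Connection\nfrom exo.shared.types.worker.instances import InstanceMeta\nfrom exo.shared.types.worker.shards import Sharding\n\n\ndef main() -> None:\n    nodes = [NodeId(), NodeId(), NodeId()]\n    topology = Topology()\n    for node_id in nodes:\n        topology.add_node(node_id)\n\n    # Fully connect the nodes in both directions so a 3-cycle exists.\n    ip_index = 1\n    for source in nodes:\n        for sink in nodes:\n            if source is not sink:\n                topology.add_connection(\n                    Connection(\n                        source=source,\n                        sink=sink,\n                        edge=create_socket_connection(ip_index),\n                    )\n                )\n                ip_index += 1\n\n    node_memory = {\n        nodes[0]: create_node_memory(500),\n        nodes[1]: create_node_memory(500),\n        nodes[2]: create_node_memory(1000),\n    }\n    node_network = {node_id: create_node_network() for node_id in nodes}\n\n    command = PlaceInstance(\n        command_id=CommandId(),\n        model_card=ModelCard(\n            model_id=ModelId(\"test-model\"),\n            storage_size=Memory.from_bytes(2000),  # exactly fits across the cycle\n            n_layers=12,\n            hidden_size=30,\n            supports_tensor=True,\n            tasks=[ModelTask.TextGeneration],\n        ),\n        sharding=Sharding.Pipeline,\n        instance_meta=InstanceMeta.MlxRing,\n        min_nodes=3,\n    )\n\n    placements = place_instance(command, topology, {}, node_memory, node_network)\n    for instance in placements.values():\n        for shard in instance.shard_assignments.runner_to_shard.values():\n            print(shard.device_rank, shard.start_layer, shard.end_layer)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },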
  {
    "path": "src/exo/master/placement_utils.py",
    "content": "from collections.abc import Generator, Mapping\n\nfrom loguru import logger\n\nfrom exo.shared.models.model_cards import ModelCard\nfrom exo.shared.topology import Topology\nfrom exo.shared.types.common import Host, NodeId\nfrom exo.shared.types.memory import Memory\nfrom exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo\nfrom exo.shared.types.topology import Cycle, RDMAConnection, SocketConnection\nfrom exo.shared.types.worker.runners import RunnerId, ShardAssignments\nfrom exo.shared.types.worker.shards import (\n    CfgShardMetadata,\n    PipelineShardMetadata,\n    Sharding,\n    ShardMetadata,\n    TensorShardMetadata,\n)\n\n\ndef filter_cycles_by_memory(\n    cycles: list[Cycle],\n    node_memory: Mapping[NodeId, MemoryUsage],\n    required_memory: Memory,\n) -> list[Cycle]:\n    filtered_cycles: list[Cycle] = []\n    for cycle in cycles:\n        if not all(node in node_memory for node in cycle):\n            continue\n\n        total_mem = sum(\n            (node_memory[node_id].ram_available for node_id in cycle.node_ids),\n            start=Memory(),\n        )\n        if total_mem >= required_memory:\n            filtered_cycles.append(cycle)\n    return filtered_cycles\n\n\ndef get_smallest_cycles(\n    cycles: list[Cycle],\n) -> list[Cycle]:\n    min_nodes = min(len(cycle) for cycle in cycles)\n    return [cycle for cycle in cycles if len(cycle) == min_nodes]\n\n\ndef allocate_layers_proportionally(\n    total_layers: int,\n    memory_fractions: list[float],\n) -> list[int]:\n    n = len(memory_fractions)\n    if n == 0:\n        raise ValueError(\"Cannot allocate layers to an empty node list\")\n    if total_layers < n:\n        raise ValueError(\n            f\"Cannot distribute {total_layers} layers across {n} nodes \"\n            \"(need at least 1 layer per node)\"\n        )\n\n    # Largest remainder: floor each, then distribute remainder by fractional part\n    raw = [f * total_layers for f in memory_fractions]\n    result = [int(r) for r in raw]\n    by_remainder = sorted(range(n), key=lambda i: raw[i] - result[i], reverse=True)\n    for i in range(total_layers - sum(result)):\n        result[by_remainder[i]] += 1\n\n    # Ensure minimum 1 per node by taking from the largest\n    for i in range(n):\n        if result[i] == 0:\n            max_idx = max(range(n), key=lambda j: result[j])\n            assert result[max_idx] > 1\n            result[max_idx] -= 1\n            result[i] = 1\n\n    return result\n\n\ndef _validate_cycle(cycle: Cycle) -> None:\n    if not cycle.node_ids:\n        raise ValueError(\"Cannot create shard assignments for empty node cycle\")\n\n\ndef _compute_total_memory(\n    node_ids: list[NodeId],\n    node_memory: Mapping[NodeId, MemoryUsage],\n) -> Memory:\n    total_memory = sum(\n        (node_memory[node_id].ram_available for node_id in node_ids),\n        start=Memory(),\n    )\n    if total_memory.in_bytes == 0:\n        raise ValueError(\"Cannot create shard assignments: total available memory is 0\")\n    return total_memory\n\n\ndef _allocate_and_validate_layers(\n    node_ids: list[NodeId],\n    node_memory: Mapping[NodeId, MemoryUsage],\n    total_memory: Memory,\n    model_card: ModelCard,\n) -> list[int]:\n    layer_allocations = allocate_layers_proportionally(\n        total_layers=model_card.n_layers,\n        memory_fractions=[\n            node_memory[node_id].ram_available / total_memory for node_id in node_ids\n        ],\n    )\n\n    total_storage = model_card.storage_size\n    total_layers 
= model_card.n_layers\n    for i, node_id in enumerate(node_ids):\n        node_layers = layer_allocations[i]\n        required_memory = (total_storage * node_layers) // total_layers\n        available_memory = node_memory[node_id].ram_available\n        if required_memory > available_memory:\n            raise ValueError(\n                f\"Node {i} ({node_id}) has insufficient memory: \"\n                f\"requires {required_memory.in_gb:.2f} GB for {node_layers} layers, \"\n                f\"but only has {available_memory.in_gb:.2f} GB available\"\n            )\n\n    return layer_allocations\n\n\ndef get_shard_assignments_for_pipeline_parallel(\n    model_card: ModelCard,\n    cycle: Cycle,\n    node_memory: Mapping[NodeId, MemoryUsage],\n) -> ShardAssignments:\n    \"\"\"Create shard assignments for pipeline parallel execution.\"\"\"\n    world_size = len(cycle)\n    use_cfg_parallel = model_card.uses_cfg and world_size >= 2 and world_size % 2 == 0\n\n    if use_cfg_parallel:\n        return _get_shard_assignments_for_cfg_parallel(model_card, cycle, node_memory)\n    else:\n        return _get_shard_assignments_for_pure_pipeline(model_card, cycle, node_memory)\n\n\ndef _get_shard_assignments_for_cfg_parallel(\n    model_card: ModelCard,\n    cycle: Cycle,\n    node_memory: Mapping[NodeId, MemoryUsage],\n) -> ShardAssignments:\n    \"\"\"Create shard assignments for CFG parallel execution.\n\n    CFG parallel runs two independent pipelines. Group 0 processes the positive\n    prompt, group 1 processes the negative prompt. The ring topology places\n    group 1's ranks in reverse order so both \"last stages\" are neighbors for\n    efficient CFG exchange.\n    \"\"\"\n    _validate_cycle(cycle)\n\n    world_size = len(cycle)\n    cfg_world_size = 2\n    pipeline_world_size = world_size // cfg_world_size\n\n    # Allocate layers for one pipeline group (both groups run the same layers)\n    pipeline_node_ids = cycle.node_ids[:pipeline_world_size]\n    pipeline_memory = _compute_total_memory(pipeline_node_ids, node_memory)\n    layer_allocations = _allocate_and_validate_layers(\n        pipeline_node_ids, node_memory, pipeline_memory, model_card\n    )\n\n    # Ring topology: group 0 ascending [0,1,2,...], group 1 descending [...,2,1,0]\n    # This places both last stages as neighbors for CFG exchange.\n    position_to_cfg_pipeline = [(0, r) for r in range(pipeline_world_size)] + [\n        (1, r) for r in reversed(range(pipeline_world_size))\n    ]\n\n    runner_to_shard: dict[RunnerId, ShardMetadata] = {}\n    node_to_runner: dict[NodeId, RunnerId] = {}\n\n    for device_rank, node_id in enumerate(cycle.node_ids):\n        cfg_rank, pipeline_rank = position_to_cfg_pipeline[device_rank]\n        layers_before = sum(layer_allocations[:pipeline_rank])\n        node_layers = layer_allocations[pipeline_rank]\n\n        shard = CfgShardMetadata(\n            model_card=model_card,\n            device_rank=device_rank,\n            world_size=world_size,\n            start_layer=layers_before,\n            end_layer=layers_before + node_layers,\n            n_layers=model_card.n_layers,\n            cfg_rank=cfg_rank,\n            cfg_world_size=cfg_world_size,\n            pipeline_rank=pipeline_rank,\n            pipeline_world_size=pipeline_world_size,\n        )\n\n        runner_id = RunnerId()\n        runner_to_shard[runner_id] = shard\n        node_to_runner[node_id] = runner_id\n\n    return ShardAssignments(\n        model_id=model_card.model_id,\n        
runner_to_shard=runner_to_shard,\n        node_to_runner=node_to_runner,\n    )\n\n\ndef _get_shard_assignments_for_pure_pipeline(\n    model_card: ModelCard,\n    cycle: Cycle,\n    node_memory: Mapping[NodeId, MemoryUsage],\n) -> ShardAssignments:\n    \"\"\"Create shard assignments for pure pipeline execution.\"\"\"\n    _validate_cycle(cycle)\n    total_memory = _compute_total_memory(cycle.node_ids, node_memory)\n\n    layer_allocations = _allocate_and_validate_layers(\n        cycle.node_ids, node_memory, total_memory, model_card\n    )\n\n    runner_to_shard: dict[RunnerId, ShardMetadata] = {}\n    node_to_runner: dict[NodeId, RunnerId] = {}\n\n    for pipeline_rank, node_id in enumerate(cycle.node_ids):\n        layers_before = sum(layer_allocations[:pipeline_rank])\n        node_layers = layer_allocations[pipeline_rank]\n\n        shard = PipelineShardMetadata(\n            model_card=model_card,\n            device_rank=pipeline_rank,\n            world_size=len(cycle),\n            start_layer=layers_before,\n            end_layer=layers_before + node_layers,\n            n_layers=model_card.n_layers,\n        )\n\n        runner_id = RunnerId()\n        runner_to_shard[runner_id] = shard\n        node_to_runner[node_id] = runner_id\n\n    return ShardAssignments(\n        model_id=model_card.model_id,\n        runner_to_shard=runner_to_shard,\n        node_to_runner=node_to_runner,\n    )\n\n\ndef get_shard_assignments_for_tensor_parallel(\n    model_card: ModelCard,\n    cycle: Cycle,\n) -> ShardAssignments:\n    total_layers = model_card.n_layers\n    world_size = len(cycle)\n    runner_to_shard: dict[RunnerId, ShardMetadata] = {}\n    node_to_runner: dict[NodeId, RunnerId] = {}\n\n    for i, node_id in enumerate(cycle):\n        shard = TensorShardMetadata(\n            model_card=model_card,\n            device_rank=i,\n            world_size=world_size,\n            start_layer=0,\n            end_layer=total_layers,\n            n_layers=total_layers,\n        )\n\n        runner_id = RunnerId()\n\n        runner_to_shard[runner_id] = shard\n        node_to_runner[node_id] = runner_id\n\n    return ShardAssignments(\n        model_id=model_card.model_id,\n        runner_to_shard=runner_to_shard,\n        node_to_runner=node_to_runner,\n    )\n\n\ndef get_shard_assignments(\n    model_card: ModelCard,\n    cycle: Cycle,\n    sharding: Sharding,\n    node_memory: Mapping[NodeId, MemoryUsage],\n) -> ShardAssignments:\n    match sharding:\n        case Sharding.Pipeline:\n            return get_shard_assignments_for_pipeline_parallel(\n                model_card=model_card,\n                cycle=cycle,\n                node_memory=node_memory,\n            )\n        case Sharding.Tensor:\n            return get_shard_assignments_for_tensor_parallel(\n                model_card=model_card,\n                cycle=cycle,\n            )\n\n\ndef get_mlx_jaccl_devices_matrix(\n    selected_cycle: list[NodeId],\n    cycle_digraph: Topology,\n) -> list[list[str | None]]:\n    \"\"\"Build connectivity matrix mapping device i to device j via RDMA interface names.\n\n    The matrix element [i][j] contains the interface name on device i that connects\n    to device j. Diagonal elements are always None; any missing off-diagonal RDMA\n    connection raises ValueError, since the jaccl backend requires all-to-all RDMA.\n    \"\"\"\n    num_nodes = len(selected_cycle)\n    matrix: list[list[str | None]] = [\n        [None for _ in range(num_nodes)] for _ in range(num_nodes)\n    ]\n\n    for i, node_i in 
enumerate(selected_cycle):\n        for j, node_j in enumerate(selected_cycle):\n            if i == j:\n                continue\n\n            for conn in cycle_digraph.get_all_connections_between(node_i, node_j):\n                if isinstance(conn, RDMAConnection):\n                    matrix[i][j] = conn.source_rdma_iface\n                    break\n            else:\n                raise ValueError(\n                    \"Current jaccl backend requires all-to-all RDMA connections\"\n                )\n\n    return matrix\n\n\ndef _find_connection_ip(\n    node_i: NodeId,\n    node_j: NodeId,\n    cycle_digraph: Topology,\n) -> Generator[str, None, None]:\n    \"\"\"Find all IP addresses that connect node i to node j.\"\"\"\n    for connection in cycle_digraph.get_all_connections_between(node_i, node_j):\n        if isinstance(connection, SocketConnection):\n            yield connection.sink_multiaddr.ip_address\n\n\ndef _find_ip_prioritised(\n    node_id: NodeId,\n    other_node_id: NodeId,\n    cycle_digraph: Topology,\n    node_network: Mapping[NodeId, NodeNetworkInfo],\n    ring: bool,\n) -> str | None:\n    \"\"\"Find an IP address between nodes with prioritization.\n\n    Priority: ethernet > wifi > unknown > thunderbolt\n    \"\"\"\n    ips = list(_find_connection_ip(node_id, other_node_id, cycle_digraph))\n    if not ips:\n        return None\n    other_network = node_network.get(other_node_id, NodeNetworkInfo())\n    ip_to_type = {\n        iface.ip_address: iface.interface_type for iface in other_network.interfaces\n    }\n\n    # Ring should prioritise fastest connection. As a best-effort, we prioritise TB.\n    # TODO: Profile and get actual connection speeds.\n    if ring:\n        priority = {\n            \"thunderbolt\": 0,\n            \"maybe_ethernet\": 1,\n            \"ethernet\": 2,\n            \"wifi\": 3,\n            \"unknown\": 4,\n        }\n\n    # RDMA prefers ethernet coordinator\n    else:\n        priority = {\n            \"ethernet\": 0,\n            \"wifi\": 1,\n            \"unknown\": 2,\n            \"maybe_ethernet\": 3,\n            \"thunderbolt\": 4,\n        }\n    return min(ips, key=lambda ip: priority.get(ip_to_type.get(ip, \"unknown\"), 2))\n\n\ndef get_mlx_ring_hosts_by_node(\n    selected_cycle: Cycle,\n    cycle_digraph: Topology,\n    ephemeral_port: int,\n    node_network: Mapping[NodeId, NodeNetworkInfo],\n) -> dict[NodeId, list[Host]]:\n    \"\"\"Generate per-node host lists for MLX ring backend.\n\n    Each node gets a list where:\n    - Self position: Host(ip=\"0.0.0.0\", port=ephemeral_port)\n    - Left/right neighbors: actual connection IPs\n    - Non-neighbors: Host(ip=\"198.51.100.1\", port=0) placeholder (RFC 5737 TEST-NET-2)\n    \"\"\"\n    world_size = len(selected_cycle)\n    if world_size == 0:\n        return {}\n\n    hosts_by_node: dict[NodeId, list[Host]] = {}\n\n    for rank, node_id in enumerate(selected_cycle):\n        left_rank = (rank - 1) % world_size\n        right_rank = (rank + 1) % world_size\n\n        hosts_for_node: list[Host] = []\n\n        for idx, other_node_id in enumerate(selected_cycle):\n            if idx == rank:\n                hosts_for_node.append(Host(ip=\"0.0.0.0\", port=ephemeral_port))\n                continue\n\n            if idx not in {left_rank, right_rank}:\n                # Placeholder IP from RFC 5737 TEST-NET-2\n                hosts_for_node.append(Host(ip=\"198.51.100.1\", port=0))\n                continue\n\n            connection_ip = _find_ip_prioritised(\n        
        node_id, other_node_id, cycle_digraph, node_network, ring=True\n            )\n            if connection_ip is None:\n                raise ValueError(\n                    \"MLX ring backend requires connectivity between neighbouring nodes\"\n                )\n\n            hosts_for_node.append(Host(ip=connection_ip, port=ephemeral_port))\n\n        hosts_by_node[node_id] = hosts_for_node\n\n    return hosts_by_node\n\n\ndef get_mlx_jaccl_coordinators(\n    coordinator: NodeId,\n    coordinator_port: int,\n    cycle_digraph: Topology,\n    node_network: Mapping[NodeId, NodeNetworkInfo],\n) -> dict[NodeId, str]:\n    \"\"\"Get the coordinator addresses for MLX JACCL (rank 0 device).\n\n    Select an IP address that each node can reach for the rank 0 node. Returns\n    address in format \"X.X.X.X:PORT\" per node.\n    \"\"\"\n    logger.debug(f\"Selecting coordinator: {coordinator}\")\n\n    def get_ip_for_node(n: NodeId) -> str:\n        if n == coordinator:\n            return \"0.0.0.0\"\n\n        ip = _find_ip_prioritised(\n            n, coordinator, cycle_digraph, node_network, ring=False\n        )\n        if ip is not None:\n            return ip\n\n        raise ValueError(\n            \"Current jaccl backend requires all participating devices to be able to communicate\"\n        )\n\n    return {\n        n: f\"{get_ip_for_node(n)}:{coordinator_port}\"\n        for n in cycle_digraph.list_nodes()\n    }\n"
  },
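  {
    "path": "examples/editor_sketch_layer_allocation.py",
    "content": "\"\"\"Illustrative sketch added by the editor; not part of the original exo tree.\n\nWorked examples for two pieces of src/exo/master/placement_utils.py. First, the\nlargest-remainder layer split: every node gets the floor of its proportional share,\nthen the leftover layers go to the largest fractional parts. Second, the CFG ring\nordering: group 0 counts pipeline ranks up while group 1 counts down, so both last\nstages end up adjacent on the ring. The file path is hypothetical; the first two\nasserts run only inside this repo.\n\"\"\"\n\nfrom exo.master.placement_utils import allocate_layers_proportionally\n\n# 500/500/1000 of memory for 12 layers -> 3/3/6, matching the placement tests.\nassert allocate_layers_proportionally(12, [0.25, 0.25, 0.5]) == [3, 3, 6]\n\n# Floors are [3, 4, 4], leaving one layer over; it goes to the largest\n# fractional part (0.39 * 12 = 4.68).\nassert allocate_layers_proportionally(12, [0.26, 0.39, 0.35]) == [3, 5, 4]\n\n# CFG ring ordering for world_size=4 (pipeline_world_size=2): entries are\n# (cfg_rank, pipeline_rank) per device rank, as built in\n# _get_shard_assignments_for_cfg_parallel.\npipeline_world_size = 2\nposition_to_cfg_pipeline = [(0, r) for r in range(pipeline_world_size)] + [\n    (1, r) for r in reversed(range(pipeline_world_size))\n]\n# Device ranks 1 and 2 are the two last stages, sitting next to each other.\nassert position_to_cfg_pipeline == [(0, 0), (0, 1), (1, 1), (1, 0)]\n"
  },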
  {
    "path": "src/exo/master/tests/conftest.py",
    "content": "from exo.shared.types.multiaddr import Multiaddr\nfrom exo.shared.types.profiling import (\n    MemoryUsage,\n    NetworkInterfaceInfo,\n    NodeNetworkInfo,\n)\nfrom exo.shared.types.topology import RDMAConnection, SocketConnection\n\n\ndef create_node_memory(memory: int) -> MemoryUsage:\n    return MemoryUsage.from_bytes(\n        ram_total=1000,\n        ram_available=memory,\n        swap_total=1000,\n        swap_available=1000,\n    )\n\n\ndef create_node_network() -> NodeNetworkInfo:\n    return NodeNetworkInfo(\n        interfaces=[\n            NetworkInterfaceInfo(name=\"en0\", ip_address=f\"169.254.0.{i}\")\n            for i in range(10)\n        ]\n    )\n\n\ndef create_socket_connection(ip: int, sink_port: int = 1234) -> SocketConnection:\n    return SocketConnection(\n        sink_multiaddr=Multiaddr(address=f\"/ip4/169.254.0.{ip}/tcp/{sink_port}\"),\n    )\n\n\ndef create_rdma_connection(iface: int) -> RDMAConnection:\n    return RDMAConnection(\n        source_rdma_iface=f\"rdma_en{iface}\", sink_rdma_iface=f\"rdma_en{iface}\"\n    )\n"
  },
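  {
    "path": "examples/editor_sketch_conftest_topology.py",
    "content": "\"\"\"Illustrative sketch added by the editor; not part of the original exo tree.\n\nShows how the conftest factories compose into a minimal two-node topology: two\ndirected SocketConnection edges form one 2-cycle, which is what the placement tests\nfeed into filter_cycles_by_memory. The file path is hypothetical; the code runs only\ninside this repo.\n\"\"\"\n\nfrom exo.master.placement_utils import filter_cycles_by_memory\nfrom exo.master.tests.conftest import create_node_memory, create_socket_connection\nfrom exo.shared.topology import Topology\nfrom exo.shared.types.common import NodeId\nfrom exo.shared.types.memory import Memory\nfrom exo.shared.types.topology import Connection\n\nnode_a, node_b = NodeId(), NodeId()\ntopology = Topology()\ntopology.add_node(node_a)\ntopology.add_node(node_b)\ntopology.add_connection(\n    Connection(source=node_a, sink=node_b, edge=create_socket_connection(1))\n)\ntopology.add_connection(\n    Connection(source=node_b, sink=node_a, edge=create_socket_connection(2))\n)\n\ntwo_cycles = [cycle for cycle in topology.get_cycles() if len(cycle) == 2]\nassert len(two_cycles) == 1\n\nnode_memory = {node_a: create_node_memory(1000), node_b: create_node_memory(1000)}\n# 1000 + 1000 bytes of available RAM just covers a 2000-byte requirement.\nassert len(filter_cycles_by_memory(two_cycles, node_memory, Memory.from_bytes(2000))) == 1\n"
  },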
  {
    "path": "src/exo/master/tests/test_master.py",
    "content": "from datetime import datetime, timezone\nfrom typing import Sequence\n\nimport anyio\nimport pytest\nfrom loguru import logger\n\nfrom exo.master.main import Master\nfrom exo.routing.router import get_node_id_keypair\nfrom exo.shared.models.model_cards import ModelCard, ModelTask\nfrom exo.shared.types.commands import (\n    CommandId,\n    ForwarderCommand,\n    ForwarderDownloadCommand,\n    PlaceInstance,\n    TextGeneration,\n)\nfrom exo.shared.types.common import ModelId, NodeId, SessionId, SystemId\nfrom exo.shared.types.events import (\n    Event,\n    GlobalForwarderEvent,\n    IndexedEvent,\n    InstanceCreated,\n    LocalForwarderEvent,\n    NodeGatheredInfo,\n    TaskCreated,\n)\nfrom exo.shared.types.memory import Memory\nfrom exo.shared.types.profiling import (\n    MemoryUsage,\n)\nfrom exo.shared.types.tasks import TaskStatus\nfrom exo.shared.types.tasks import TextGeneration as TextGenerationTask\nfrom exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams\nfrom exo.shared.types.worker.instances import (\n    InstanceMeta,\n    MlxRingInstance,\n    ShardAssignments,\n)\nfrom exo.shared.types.worker.shards import PipelineShardMetadata, Sharding\nfrom exo.utils.channels import channel\n\n\n@pytest.mark.asyncio\nasync def test_master():\n    keypair = get_node_id_keypair()\n    node_id = NodeId(keypair.to_node_id())\n    session_id = SessionId(master_node_id=node_id, election_clock=0)\n\n    ge_sender, global_event_receiver = channel[GlobalForwarderEvent]()\n    command_sender, co_receiver = channel[ForwarderCommand]()\n    local_event_sender, le_receiver = channel[LocalForwarderEvent]()\n    fcds, _fcdr = channel[ForwarderDownloadCommand]()\n    ev_send, ev_recv = channel[Event]()\n\n    async def mock_event_router():\n        idx = 0\n        sid = SystemId()\n        with ev_recv as master_events:\n            async for event in master_events:\n                await local_event_sender.send(\n                    LocalForwarderEvent(\n                        origin=sid,\n                        origin_idx=idx,\n                        session=session_id,\n                        event=event,\n                    )\n                )\n                idx += 1\n\n    all_events: list[IndexedEvent] = []\n\n    def _get_events() -> Sequence[IndexedEvent]:\n        orig_events = global_event_receiver.collect()\n        for e in orig_events:\n            all_events.append(\n                IndexedEvent(\n                    event=e.event,\n                    idx=len(all_events),  # origin=e.origin,\n                )\n            )\n        return all_events\n\n    master = Master(\n        node_id,\n        session_id,\n        event_sender=ev_send,\n        global_event_sender=ge_sender,\n        local_event_receiver=le_receiver,\n        command_receiver=co_receiver,\n        download_command_sender=fcds,\n    )\n    logger.info(\"run the master\")\n    async with anyio.create_task_group() as tg:\n        tg.start_soon(master.run)\n        tg.start_soon(mock_event_router)\n\n        # inject a NodeGatheredInfo event\n        logger.info(\"inject a NodeGatheredInfo event\")\n        await local_event_sender.send(\n            LocalForwarderEvent(\n                origin_idx=0,\n                origin=SystemId(\"Worker\"),\n                session=session_id,\n                event=(\n                    NodeGatheredInfo(\n                        when=str(datetime.now(tz=timezone.utc)),\n                        node_id=node_id,\n           
             info=MemoryUsage(\n                            ram_total=Memory.from_bytes(678948 * 1024),\n                            ram_available=Memory.from_bytes(678948 * 1024),\n                            swap_total=Memory.from_bytes(0),\n                            swap_available=Memory.from_bytes(0),\n                        ),\n                    )\n                ),\n            )\n        )\n\n        # wait for initial topology event\n        logger.info(\"wait for initial topology event\")\n        while len(list(master.state.topology.list_nodes())) == 0:\n            await anyio.sleep(0.001)\n        while len(master.state.node_memory) == 0:\n            await anyio.sleep(0.001)\n\n        logger.info(\"inject a PlaceInstance Command\")\n        await command_sender.send(\n            ForwarderCommand(\n                origin=SystemId(\"API\"),\n                command=(\n                    PlaceInstance(\n                        command_id=CommandId(),\n                        model_card=ModelCard(\n                            model_id=ModelId(\"llama-3.2-1b\"),\n                            n_layers=16,\n                            storage_size=Memory.from_bytes(678948),\n                            hidden_size=7168,\n                            supports_tensor=True,\n                            tasks=[ModelTask.TextGeneration],\n                        ),\n                        sharding=Sharding.Pipeline,\n                        instance_meta=InstanceMeta.MlxRing,\n                        min_nodes=1,\n                    )\n                ),\n            )\n        )\n        logger.info(\"wait for an instance\")\n        while len(master.state.instances.keys()) == 0:\n            await anyio.sleep(0.001)\n        logger.info(\"inject a TextGeneration Command\")\n        await command_sender.send(\n            ForwarderCommand(\n                origin=SystemId(\"API\"),\n                command=(\n                    TextGeneration(\n                        command_id=CommandId(),\n                        task_params=TextGenerationTaskParams(\n                            model=ModelId(\"llama-3.2-1b\"),\n                            input=[\n                                InputMessage(role=\"user\", content=\"Hello, how are you?\")\n                            ],\n                        ),\n                    )\n                ),\n            )\n        )\n        while len(_get_events()) < 3:\n            await anyio.sleep(0.01)\n\n        events = _get_events()\n        assert len(events) == 3\n        assert events[0].idx == 0\n        assert events[1].idx == 1\n        assert events[2].idx == 2\n        assert isinstance(events[0].event, NodeGatheredInfo)\n        assert isinstance(events[1].event, InstanceCreated)\n        created_instance = events[1].event.instance\n        assert isinstance(created_instance, MlxRingInstance)\n        runner_id = list(created_instance.shard_assignments.runner_to_shard.keys())[0]\n        # Validate the shard assignments\n        expected_shard_assignments = ShardAssignments(\n            model_id=ModelId(\"llama-3.2-1b\"),\n            runner_to_shard={\n                runner_id: PipelineShardMetadata(\n                    start_layer=0,\n                    end_layer=16,\n                    n_layers=16,\n                    model_card=ModelCard(\n                        model_id=ModelId(\"llama-3.2-1b\"),\n                        n_layers=16,\n                        storage_size=Memory.from_bytes(678948),\n            
            hidden_size=7168,\n                        supports_tensor=True,\n                        tasks=[ModelTask.TextGeneration],\n                    ),\n                    device_rank=0,\n                    world_size=1,\n                )\n            },\n            node_to_runner={node_id: runner_id},\n        )\n        assert created_instance.shard_assignments == expected_shard_assignments\n        # For single-node, hosts_by_node should have one entry with self-binding\n        assert len(created_instance.hosts_by_node) == 1\n        assert node_id in created_instance.hosts_by_node\n        assert len(created_instance.hosts_by_node[node_id]) == 1\n        assert created_instance.hosts_by_node[node_id][0].ip == \"0.0.0.0\"\n        assert created_instance.ephemeral_port > 0\n        assert isinstance(events[2].event, TaskCreated)\n        assert events[2].event.task.task_status == TaskStatus.Pending\n        assert isinstance(events[2].event.task, TextGenerationTask)\n        assert events[2].event.task.task_params == TextGenerationTaskParams(\n            model=ModelId(\"llama-3.2-1b\"),\n            input=[InputMessage(role=\"user\", content=\"Hello, how are you?\")],\n        )\n\n        ev_send.close()\n        await master.shutdown()\n"
  },
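  {
    "path": "src/exo/master/tests/editor_sketch_single_node_fallback.py",
    "content": "\"\"\"Illustrative pytest-style sketch added by the editor; not part of the original exo tree.\n\nExercises place_instance's single-node fallback: when the selected cycle has one\nnode, placement falls back to a Pipeline/MlxRing instance regardless of the requested\nsharding, and the sole shard spans every layer. Fixtures mirror test_placement.py;\nthe file path is hypothetical and the test runs only inside this repo.\n\"\"\"\n\nfrom exo.master.placement import place_instance\nfrom exo.master.tests.conftest import create_node_memory, create_node_network\nfrom exo.shared.models.model_cards import ModelCard, ModelId, ModelTask\nfrom exo.shared.topology import Topology\nfrom exo.shared.types.commands import PlaceInstance\nfrom exo.shared.types.common import CommandId, NodeId\nfrom exo.shared.types.memory import Memory\nfrom exo.shared.types.worker.instances import InstanceMeta, MlxRingInstance\nfrom exo.shared.types.worker.shards import Sharding\n\n\ndef test_single_node_falls_back_to_pipeline_ring() -> None:\n    node_id = NodeId()\n    topology = Topology()\n    topology.add_node(node_id)\n\n    command = PlaceInstance(\n        command_id=CommandId(),\n        model_card=ModelCard(\n            model_id=ModelId(\"test-model\"),\n            storage_size=Memory.from_kb(1000),\n            n_layers=10,\n            hidden_size=30,\n            supports_tensor=True,\n            tasks=[ModelTask.TextGeneration],\n        ),\n        sharding=Sharding.Tensor,  # requested, but a 1-node cycle forces Pipeline\n        instance_meta=InstanceMeta.MlxRing,\n        min_nodes=1,\n    )\n\n    placements = place_instance(\n        command,\n        topology,\n        {},\n        {node_id: create_node_memory(1000 * 1024)},\n        {node_id: create_node_network()},\n    )\n\n    (instance,) = placements.values()\n    assert isinstance(instance, MlxRingInstance)\n    (shard,) = instance.shard_assignments.runner_to_shard.values()\n    assert (shard.start_layer, shard.end_layer) == (0, 10)\n"
  },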
  {
    "path": "src/exo/master/tests/test_placement.py",
    "content": "import pytest\n\nfrom exo.master.placement import (\n    get_transition_events,\n    place_instance,\n)\nfrom exo.master.tests.conftest import (\n    create_node_memory,\n    create_node_network,\n    create_rdma_connection,\n    create_socket_connection,\n)\nfrom exo.shared.models.model_cards import ModelCard, ModelId, ModelTask\nfrom exo.shared.topology import Topology\nfrom exo.shared.types.commands import PlaceInstance\nfrom exo.shared.types.common import CommandId, NodeId\nfrom exo.shared.types.events import (\n    InstanceCreated,\n    InstanceDeleted,\n    TaskStatusUpdated,\n)\nfrom exo.shared.types.memory import Memory\nfrom exo.shared.types.multiaddr import Multiaddr\nfrom exo.shared.types.profiling import NetworkInterfaceInfo, NodeNetworkInfo\nfrom exo.shared.types.tasks import TaskId, TaskStatus, TextGeneration\nfrom exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams\nfrom exo.shared.types.topology import Connection, SocketConnection\nfrom exo.shared.types.worker.instances import (\n    Instance,\n    InstanceId,\n    InstanceMeta,\n    MlxJacclInstance,\n    MlxRingInstance,\n)\nfrom exo.shared.types.worker.runners import ShardAssignments\nfrom exo.shared.types.worker.shards import Sharding\n\n\n@pytest.fixture\ndef instance() -> Instance:\n    return MlxRingInstance(\n        instance_id=InstanceId(),\n        shard_assignments=ShardAssignments(\n            model_id=ModelId(\"test-model\"), runner_to_shard={}, node_to_runner={}\n        ),\n        hosts_by_node={},\n        ephemeral_port=50000,\n    )\n\n\n@pytest.fixture\ndef model_card() -> ModelCard:\n    return ModelCard(\n        model_id=ModelId(\"test-model\"),\n        storage_size=Memory.from_kb(1000),\n        n_layers=10,\n        hidden_size=30,\n        supports_tensor=True,\n        tasks=[ModelTask.TextGeneration],\n    )\n\n\ndef place_instance_command(model_card: ModelCard) -> PlaceInstance:\n    return PlaceInstance(\n        command_id=CommandId(),\n        model_card=model_card,\n        sharding=Sharding.Pipeline,\n        instance_meta=InstanceMeta.MlxRing,\n        min_nodes=1,\n    )\n\n\n@pytest.mark.parametrize(\n    \"available_memory,total_layers,expected_layers\",\n    [\n        ((500, 500, 1000), 12, (3, 3, 6)),\n        ((500, 500, 500), 12, (4, 4, 4)),\n        ((312, 468, 1092), 12, (2, 3, 7)),\n    ],\n)\ndef test_get_instance_placements_create_instance(\n    available_memory: tuple[int, int, int],\n    total_layers: int,\n    expected_layers: tuple[int, int, int],\n    model_card: ModelCard,\n):\n    # arrange\n    model_card.n_layers = total_layers\n    model_card.storage_size = Memory.from_bytes(\n        sum(available_memory)\n    )  # make it exactly fit across all nodes\n    topology = Topology()\n\n    cic = place_instance_command(model_card)\n    node_id_a = NodeId()\n    node_id_b = NodeId()\n    node_id_c = NodeId()\n\n    # fully connected (directed) between the 3 nodes\n    conn_a_b = Connection(\n        source=node_id_a, sink=node_id_b, edge=create_socket_connection(1)\n    )\n    conn_b_c = Connection(\n        source=node_id_b, sink=node_id_c, edge=create_socket_connection(2)\n    )\n    conn_c_a = Connection(\n        source=node_id_c, sink=node_id_a, edge=create_socket_connection(3)\n    )\n    conn_c_b = Connection(\n        source=node_id_c, sink=node_id_b, edge=create_socket_connection(4)\n    )\n    conn_a_c = Connection(\n        source=node_id_a, sink=node_id_c, edge=create_socket_connection(5)\n    )\n    conn_b_a = 
Connection(\n        source=node_id_b, sink=node_id_a, edge=create_socket_connection(6)\n    )\n\n    node_memory = {\n        node_id_a: create_node_memory(available_memory[0]),\n        node_id_b: create_node_memory(available_memory[1]),\n        node_id_c: create_node_memory(available_memory[2]),\n    }\n    node_network = {\n        node_id_a: create_node_network(),\n        node_id_b: create_node_network(),\n        node_id_c: create_node_network(),\n    }\n    topology.add_node(node_id_a)\n    topology.add_node(node_id_b)\n    topology.add_node(node_id_c)\n    topology.add_connection(conn_a_b)\n    topology.add_connection(conn_b_c)\n    topology.add_connection(conn_c_a)\n    topology.add_connection(conn_c_b)\n    topology.add_connection(conn_a_c)\n    topology.add_connection(conn_b_a)\n\n    # act\n    placements = place_instance(cic, topology, {}, node_memory, node_network)\n\n    # assert\n    assert len(placements) == 1\n    instance_id = list(placements.keys())[0]\n    instance = placements[instance_id]\n    assert instance.shard_assignments.model_id == model_card.model_id\n\n    runner_id_a = instance.shard_assignments.node_to_runner[node_id_a]\n    runner_id_b = instance.shard_assignments.node_to_runner[node_id_b]\n    runner_id_c = instance.shard_assignments.node_to_runner[node_id_c]\n\n    shard_a = instance.shard_assignments.runner_to_shard[runner_id_a]\n    shard_b = instance.shard_assignments.runner_to_shard[runner_id_b]\n    shard_c = instance.shard_assignments.runner_to_shard[runner_id_c]\n\n    assert shard_a.end_layer - shard_a.start_layer == expected_layers[0]\n    assert shard_b.end_layer - shard_b.start_layer == expected_layers[1]\n    assert shard_c.end_layer - shard_c.start_layer == expected_layers[2]\n\n    shards = [shard_a, shard_b, shard_c]\n    shards_sorted = sorted(shards, key=lambda s: s.start_layer)\n    assert shards_sorted[0].start_layer == 0\n    assert shards_sorted[-1].end_layer == total_layers\n\n\ndef test_get_instance_placements_one_node_exact_fit() -> None:\n    topology = Topology()\n    node_id = NodeId()\n    topology.add_node(node_id)\n    node_memory = {node_id: create_node_memory(1000 * 1024)}\n    node_network = {node_id: create_node_network()}\n    cic = place_instance_command(\n        ModelCard(\n            model_id=ModelId(\"test-model\"),\n            storage_size=Memory.from_kb(1000),\n            n_layers=10,\n            hidden_size=1000,\n            supports_tensor=True,\n            tasks=[ModelTask.TextGeneration],\n        ),\n    )\n    placements = place_instance(cic, topology, {}, node_memory, node_network)\n\n    assert len(placements) == 1\n    instance_id = list(placements.keys())[0]\n    instance = placements[instance_id]\n    assert instance.shard_assignments.model_id == \"test-model\"\n    assert len(instance.shard_assignments.node_to_runner) == 1\n    assert len(instance.shard_assignments.runner_to_shard) == 1\n\n\ndef test_get_instance_placements_one_node_fits_with_extra_memory() -> None:\n    topology = Topology()\n    node_id = NodeId()\n    topology.add_node(node_id)\n    node_memory = {node_id: create_node_memory(1001 * 1024)}\n    node_network = {node_id: create_node_network()}\n    cic = place_instance_command(\n        ModelCard(\n            model_id=ModelId(\"test-model\"),\n            storage_size=Memory.from_kb(1000),\n            n_layers=10,\n            hidden_size=1000,\n            supports_tensor=True,\n            tasks=[ModelTask.TextGeneration],\n        ),\n    )\n    placements = place_instance(cic, topology, {}, node_memory, node_network)\n\n    assert len(placements) == 1\n    instance_id = list(placements.keys())[0]\n    instance = placements[instance_id]\n    assert instance.shard_assignments.model_id == \"test-model\"\n    assert len(instance.shard_assignments.node_to_runner) == 1\n    assert len(instance.shard_assignments.runner_to_shard) == 1\n\n\ndef test_get_instance_placements_one_node_not_fit() -> None:\n    topology = Topology()\n    node_id = NodeId()\n    topology.add_node(node_id)\n    node_memory = {node_id: create_node_memory(1000 * 1024)}\n    node_network = {node_id: create_node_network()}\n    cic = place_instance_command(\n        model_card=ModelCard(\n            model_id=ModelId(\"test-model\"),\n            storage_size=Memory.from_kb(1001),\n            n_layers=10,\n            hidden_size=1000,\n            supports_tensor=True,\n            tasks=[ModelTask.TextGeneration],\n        ),\n    )\n\n    with pytest.raises(ValueError, match=\"No cycles found with sufficient memory\"):\n        place_instance(cic, topology, {}, node_memory, node_network)\n\n\ndef test_get_transition_events_no_change(instance: Instance):\n    # arrange\n    instance_id = InstanceId()\n    current_instances = {instance_id: instance}\n    target_instances = {instance_id: instance}\n\n    # act\n    events = get_transition_events(current_instances, target_instances, {})\n\n    # assert\n    assert len(events) == 0\n\n\ndef test_get_transition_events_create_instance(instance: Instance):\n    # arrange\n    instance_id = InstanceId()\n    current_instances: dict[InstanceId, Instance] = {}\n    target_instances: dict[InstanceId, Instance] = {instance_id: instance}\n\n    # act\n    events = get_transition_events(current_instances, target_instances, {})\n\n    # assert\n    assert len(events) == 1\n    assert isinstance(events[0], InstanceCreated)\n\n\ndef test_get_transition_events_delete_instance(instance: Instance):\n    # arrange\n    instance_id = InstanceId()\n    current_instances: dict[InstanceId, Instance] = {instance_id: instance}\n    target_instances: dict[InstanceId, Instance] = {}\n\n    # act\n    events = get_transition_events(current_instances, target_instances, {})\n\n    # assert\n    assert len(events) == 1\n    assert isinstance(events[0], InstanceDeleted)\n    assert events[0].instance_id == instance_id\n\n\ndef test_placement_selects_leaf_nodes(\n    model_card: ModelCard,\n):\n    # arrange\n    topology = Topology()\n\n    model_card.storage_size = Memory.from_bytes(1000)\n\n    node_id_a = NodeId()\n    node_id_b = NodeId()\n    node_id_c = NodeId()\n    node_id_d = NodeId()\n\n    node_memory = {\n        node_id_a: create_node_memory(500),\n        node_id_b: create_node_memory(600),\n        node_id_c: create_node_memory(600),\n        node_id_d: create_node_memory(500),\n    }\n    node_network = {\n        node_id_a: create_node_network(),\n        node_id_b: create_node_network(),\n        node_id_c: create_node_network(),\n        node_id_d: create_node_network(),\n    }\n\n    topology.add_node(node_id_a)\n    topology.add_node(node_id_b)\n    topology.add_node(node_id_c)\n    topology.add_node(node_id_d)\n\n    # Daisy chain topology (directed)\n    topology.add_connection(\n        Connection(source=node_id_a, sink=node_id_b, edge=create_socket_connection(1))\n    )\n    topology.add_connection(\n  
      Connection(source=node_id_b, sink=node_id_a, edge=create_socket_connection(1))\n    )\n    topology.add_connection(\n        Connection(source=node_id_b, sink=node_id_c, edge=create_socket_connection(1))\n    )\n    topology.add_connection(\n        Connection(source=node_id_c, sink=node_id_b, edge=create_socket_connection(1))\n    )\n    topology.add_connection(\n        Connection(source=node_id_c, sink=node_id_d, edge=create_socket_connection(1))\n    )\n    topology.add_connection(\n        Connection(source=node_id_d, sink=node_id_c, edge=create_socket_connection(1))\n    )\n\n    cic = place_instance_command(model_card=model_card)\n\n    # act\n    placements = place_instance(cic, topology, {}, node_memory, node_network)\n\n    # assert\n    assert len(placements) == 1\n    instance = list(placements.values())[0]\n\n    assigned_nodes = set(instance.shard_assignments.node_to_runner.keys())\n    assert assigned_nodes == set((node_id_a, node_id_b)) or assigned_nodes == set(\n        (\n            node_id_c,\n            node_id_d,\n        )\n    )\n\n\ndef test_tensor_rdma_backend_connectivity_matrix(\n    model_card: ModelCard,\n):\n    # arrange\n    topology = Topology()\n    model_card.n_layers = 12\n    model_card.storage_size = Memory.from_bytes(1500)\n\n    node_a = NodeId()\n    node_b = NodeId()\n    node_c = NodeId()\n\n    node_memory = {\n        node_a: create_node_memory(500),\n        node_b: create_node_memory(500),\n        node_c: create_node_memory(500),\n    }\n\n    ethernet_interface = NetworkInterfaceInfo(\n        name=\"en0\",\n        ip_address=\"10.0.0.1\",\n    )\n    ethernet_conn = SocketConnection(\n        sink_multiaddr=Multiaddr(address=\"/ip4/10.0.0.1/tcp/8000\")\n    )\n\n    node_network = {\n        node_a: NodeNetworkInfo(interfaces=[ethernet_interface]),\n        node_b: NodeNetworkInfo(interfaces=[ethernet_interface]),\n        node_c: NodeNetworkInfo(interfaces=[ethernet_interface]),\n    }\n\n    topology.add_node(node_a)\n    topology.add_node(node_b)\n    topology.add_node(node_c)\n\n    # RDMA connections (directed)\n    topology.add_connection(\n        Connection(source=node_a, sink=node_b, edge=create_rdma_connection(3))\n    )\n    topology.add_connection(\n        Connection(source=node_b, sink=node_a, edge=create_rdma_connection(3))\n    )\n    topology.add_connection(\n        Connection(source=node_b, sink=node_c, edge=create_rdma_connection(4))\n    )\n    topology.add_connection(\n        Connection(source=node_c, sink=node_b, edge=create_rdma_connection(4))\n    )\n    topology.add_connection(\n        Connection(source=node_a, sink=node_c, edge=create_rdma_connection(5))\n    )\n    topology.add_connection(\n        Connection(source=node_c, sink=node_a, edge=create_rdma_connection(5))\n    )\n\n    # Ethernet connections (directed)\n    topology.add_connection(Connection(source=node_a, sink=node_b, edge=ethernet_conn))\n    topology.add_connection(Connection(source=node_b, sink=node_c, edge=ethernet_conn))\n    topology.add_connection(Connection(source=node_c, sink=node_a, edge=ethernet_conn))\n    topology.add_connection(Connection(source=node_a, sink=node_c, edge=ethernet_conn))\n    topology.add_connection(Connection(source=node_b, sink=node_a, edge=ethernet_conn))\n    topology.add_connection(Connection(source=node_c, sink=node_b, edge=ethernet_conn))\n\n    cic = PlaceInstance(\n        sharding=Sharding.Tensor,\n        instance_meta=InstanceMeta.MlxJaccl,\n        command_id=CommandId(),\n        
model_card=model_card,\n        min_nodes=1,\n    )\n\n    # act\n    placements = place_instance(cic, topology, {}, node_memory, node_network)\n\n    # assert\n    assert len(placements) == 1\n    instance_id = list(placements.keys())[0]\n    instance = placements[instance_id]\n\n    assert isinstance(instance, MlxJacclInstance)\n\n    assert instance.jaccl_devices is not None\n    assert instance.jaccl_coordinators is not None\n\n    matrix = instance.jaccl_devices\n    assert len(matrix) == 3\n    for i in range(3):\n        assert matrix[i][i] is None\n\n    assigned_nodes = list(instance.shard_assignments.node_to_runner.keys())\n    node_to_idx = {node_id: idx for idx, node_id in enumerate(assigned_nodes)}\n\n    idx_a = node_to_idx[node_a]\n    idx_b = node_to_idx[node_b]\n    idx_c = node_to_idx[node_c]\n\n    assert matrix[idx_a][idx_b] == \"rdma_en3\"\n    assert matrix[idx_b][idx_c] == \"rdma_en4\"\n    assert matrix[idx_c][idx_a] == \"rdma_en5\"\n\n    # Verify coordinators are set for all nodes\n    assert len(instance.jaccl_coordinators) == 3\n    for node_id in assigned_nodes:\n        assert node_id in instance.jaccl_coordinators\n        coordinator = instance.jaccl_coordinators[node_id]\n        assert \":\" in coordinator\n        # Rank 0 node should use 0.0.0.0, others should use connection-specific IPs\n        if node_id == assigned_nodes[0]:\n            assert coordinator.startswith(\"0.0.0.0:\")\n        else:\n            ip_part = coordinator.split(\":\")[0]\n            assert len(ip_part.split(\".\")) == 4\n\n\ndef _make_task(\n    instance_id: InstanceId,\n    status: TaskStatus = TaskStatus.Running,\n) -> TextGeneration:\n    return TextGeneration(\n        task_id=TaskId(),\n        task_status=status,\n        instance_id=instance_id,\n        command_id=CommandId(),\n        task_params=TextGenerationTaskParams(\n            model=ModelId(\"test-model\"),\n            input=[InputMessage(role=\"user\", content=\"hello\")],\n        ),\n    )\n\n\ndef test_get_transition_events_delete_instance_cancels_running_tasks(\n    instance: Instance,\n):\n    # arrange\n    instance_id = InstanceId()\n    current_instances: dict[InstanceId, Instance] = {instance_id: instance}\n    target_instances: dict[InstanceId, Instance] = {}\n    task = _make_task(instance_id, TaskStatus.Running)\n    tasks = {task.task_id: task}\n\n    # act\n    events = get_transition_events(current_instances, target_instances, tasks)\n\n    # assert – cancellation event should come before the deletion event\n    assert len(events) == 2\n    assert isinstance(events[0], TaskStatusUpdated)\n    assert events[0].task_id == task.task_id\n    assert events[0].task_status == TaskStatus.Cancelled\n    assert isinstance(events[1], InstanceDeleted)\n    assert events[1].instance_id == instance_id\n\n\ndef test_get_transition_events_delete_instance_cancels_pending_tasks(\n    instance: Instance,\n):\n    # arrange\n    instance_id = InstanceId()\n    current_instances: dict[InstanceId, Instance] = {instance_id: instance}\n    target_instances: dict[InstanceId, Instance] = {}\n    task = _make_task(instance_id, TaskStatus.Pending)\n    tasks = {task.task_id: task}\n\n    # act\n    events = get_transition_events(current_instances, target_instances, tasks)\n\n    # assert\n    assert len(events) == 2\n    assert isinstance(events[0], TaskStatusUpdated)\n    assert events[0].task_id == task.task_id\n    assert events[0].task_status == TaskStatus.Cancelled\n    assert isinstance(events[1], 
InstanceDeleted)\n\n\ndef test_get_transition_events_delete_instance_ignores_completed_tasks(\n    instance: Instance,\n):\n    # arrange\n    instance_id = InstanceId()\n    current_instances: dict[InstanceId, Instance] = {instance_id: instance}\n    target_instances: dict[InstanceId, Instance] = {}\n    tasks = {\n        t.task_id: t\n        for t in [\n            _make_task(instance_id, TaskStatus.Complete),\n            _make_task(instance_id, TaskStatus.Failed),\n            _make_task(instance_id, TaskStatus.TimedOut),\n            _make_task(instance_id, TaskStatus.Cancelled),\n        ]\n    }\n\n    # act\n    events = get_transition_events(current_instances, target_instances, tasks)\n\n    # assert – only the InstanceDeleted event, no cancellations\n    assert len(events) == 1\n    assert isinstance(events[0], InstanceDeleted)\n\n\ndef test_get_transition_events_delete_instance_cancels_only_matching_tasks(\n    instance: Instance,\n):\n    # arrange\n    instance_id_a = InstanceId()\n    instance_id_b = InstanceId()\n    current_instances: dict[InstanceId, Instance] = {\n        instance_id_a: instance,\n        instance_id_b: instance,\n    }\n    # only delete instance A, keep instance B\n    target_instances: dict[InstanceId, Instance] = {instance_id_b: instance}\n\n    task_a = _make_task(instance_id_a, TaskStatus.Running)\n    task_b = _make_task(instance_id_b, TaskStatus.Running)\n    tasks = {task_a.task_id: task_a, task_b.task_id: task_b}\n\n    # act\n    events = get_transition_events(current_instances, target_instances, tasks)\n\n    # assert – only task_a should be cancelled\n    cancel_events = [e for e in events if isinstance(e, TaskStatusUpdated)]\n    delete_events = [e for e in events if isinstance(e, InstanceDeleted)]\n    assert len(cancel_events) == 1\n    assert cancel_events[0].task_id == task_a.task_id\n    assert cancel_events[0].task_status == TaskStatus.Cancelled\n    assert len(delete_events) == 1\n    assert delete_events[0].instance_id == instance_id_a\n"
  },
  {
    "path": "src/exo/master/tests/test_placement_utils.py",
    "content": "import pytest\n\nfrom exo.master.placement_utils import (\n    allocate_layers_proportionally,\n    filter_cycles_by_memory,\n    get_mlx_jaccl_coordinators,\n    get_shard_assignments,\n    get_shard_assignments_for_pipeline_parallel,\n    get_smallest_cycles,\n)\nfrom exo.master.tests.conftest import (\n    create_node_memory,\n    create_socket_connection,\n)\nfrom exo.shared.models.model_cards import ModelCard, ModelId, ModelTask\nfrom exo.shared.topology import Topology\nfrom exo.shared.types.common import NodeId\nfrom exo.shared.types.memory import Memory\nfrom exo.shared.types.profiling import (\n    NetworkInterfaceInfo,\n    NodeNetworkInfo,\n)\nfrom exo.shared.types.topology import Connection, SocketConnection\nfrom exo.shared.types.worker.shards import (\n    CfgShardMetadata,\n    PipelineShardMetadata,\n    Sharding,\n)\n\n\ndef test_filter_cycles_by_memory():\n    # arrange\n    node1_id = NodeId()\n    node2_id = NodeId()\n    connection1 = Connection(\n        source=node1_id, sink=node2_id, edge=create_socket_connection(1)\n    )\n    connection2 = Connection(\n        source=node2_id, sink=node1_id, edge=create_socket_connection(2)\n    )\n\n    node1_mem = create_node_memory(1000 * 1024)\n    node2_mem = create_node_memory(1000 * 1024)\n    node_memory = {node1_id: node1_mem, node2_id: node2_mem}\n\n    topology = Topology()\n    topology.add_node(node1_id)\n    topology.add_node(node2_id)\n    topology.add_connection(connection1)\n    topology.add_connection(connection2)\n\n    cycles = [c for c in topology.get_cycles() if len(c) != 1]\n    assert len(cycles) == 1\n    assert len(cycles[0]) == 2\n\n    # act\n    filtered_cycles = filter_cycles_by_memory(cycles, node_memory, Memory.from_bytes(1))\n\n    # assert\n    assert len(filtered_cycles) == 1\n    assert len(filtered_cycles[0]) == 2\n    assert set(n for n in filtered_cycles[0]) == {node1_id, node2_id}\n\n\ndef test_filter_cycles_by_insufficient_memory():\n    # arrange\n    node1_id = NodeId()\n    node2_id = NodeId()\n    connection1 = Connection(\n        source=node1_id, sink=node2_id, edge=create_socket_connection(1)\n    )\n    connection2 = Connection(\n        source=node2_id, sink=node1_id, edge=create_socket_connection(2)\n    )\n\n    node1_mem = create_node_memory(1000 * 1024)\n    node2_mem = create_node_memory(1000 * 1024)\n    node_memory = {node1_id: node1_mem, node2_id: node2_mem}\n\n    topology = Topology()\n    topology.add_node(node1_id)\n    topology.add_node(node2_id)\n    topology.add_connection(connection1)\n    topology.add_connection(connection2)\n\n    # act\n    filtered_cycles = filter_cycles_by_memory(\n        topology.get_cycles(), node_memory, Memory.from_kb(2001)\n    )\n\n    # assert\n    assert len(filtered_cycles) == 0\n\n\ndef test_filter_multiple_cycles_by_memory():\n    # arrange\n    node_a_id = NodeId()\n    node_b_id = NodeId()\n    node_c_id = NodeId()\n    connection1 = Connection(\n        source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)\n    )\n    connection2 = Connection(\n        source=node_b_id, sink=node_a_id, edge=create_socket_connection(2)\n    )\n    connection3 = Connection(\n        source=node_a_id, sink=node_c_id, edge=create_socket_connection(3)\n    )\n    connection4 = Connection(\n        source=node_c_id, sink=node_b_id, edge=create_socket_connection(4)\n    )\n\n    node_a_mem = create_node_memory(500 * 1024)\n    node_b_mem = create_node_memory(500 * 1024)\n    node_c_mem = create_node_memory(1000 * 1024)\n    
node_memory = {\n        node_a_id: node_a_mem,\n        node_b_id: node_b_mem,\n        node_c_id: node_c_mem,\n    }\n\n    topology = Topology()\n    topology.add_node(node_a_id)\n    topology.add_node(node_b_id)\n    topology.add_node(node_c_id)\n    topology.add_connection(connection1)\n    topology.add_connection(connection2)\n    topology.add_connection(connection3)\n    topology.add_connection(connection4)\n\n    cycles = topology.get_cycles()\n\n    # act\n    filtered_cycles = filter_cycles_by_memory(cycles, node_memory, Memory.from_kb(1500))\n\n    # assert\n    assert len(filtered_cycles) == 1\n    assert len(filtered_cycles[0]) == 3\n    assert set(n for n in filtered_cycles[0]) == {\n        node_a_id,\n        node_b_id,\n        node_c_id,\n    }\n\n\ndef test_get_smallest_cycles():\n    # arrange\n    node_a_id = NodeId()\n    node_b_id = NodeId()\n    node_c_id = NodeId()\n\n    topology = Topology()\n    topology.add_node(node_a_id)\n    topology.add_node(node_b_id)\n    topology.add_node(node_c_id)\n\n    connection1 = Connection(\n        source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)\n    )\n    connection2 = Connection(\n        source=node_b_id, sink=node_a_id, edge=create_socket_connection(2)\n    )\n    connection3 = Connection(\n        source=node_a_id, sink=node_c_id, edge=create_socket_connection(3)\n    )\n    connection4 = Connection(\n        source=node_c_id, sink=node_b_id, edge=create_socket_connection(4)\n    )\n\n    topology.add_connection(connection1)\n    topology.add_connection(connection2)\n    topology.add_connection(connection3)\n    topology.add_connection(connection4)\n\n    cycles = [c for c in topology.get_cycles() if len(c) != 1]  # ignore singletons\n\n    # act\n    smallest_cycles = get_smallest_cycles(cycles)\n\n    # assert\n    assert len(smallest_cycles) == 1\n    assert len(smallest_cycles[0]) == 2\n    assert set(n for n in smallest_cycles[0]) == {node_a_id, node_b_id}\n\n\n@pytest.mark.parametrize(\n    \"available_memory,total_layers,expected_layers\",\n    [\n        ((500, 500, 1000), 12, (3, 3, 6)),\n        ((500, 500, 500), 12, (4, 4, 4)),\n        ((312, 518, 1024), 12, (2, 3, 7)),\n        # Edge case: one node has ~90% of memory - should not over-allocate.\n        # Each node must have enough memory for at least 1 layer (50 KB = 1000/20).\n        ((900, 50, 50), 20, (18, 1, 1)),\n    ],\n)\ndef test_get_shard_assignments(\n    available_memory: tuple[int, int, int],\n    total_layers: int,\n    expected_layers: tuple[int, int, int],\n):\n    # arrange\n    node_a_id = NodeId()\n    node_b_id = NodeId()\n    node_c_id = NodeId()\n\n    # create connections (A -> B -> C -> A forms a 3-cycle, plus B -> A also exists)\n    connection1 = Connection(\n        source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)\n    )\n    connection2 = Connection(\n        source=node_b_id, sink=node_c_id, edge=create_socket_connection(2)\n    )\n    connection3 = Connection(\n        source=node_c_id, sink=node_a_id, edge=create_socket_connection(3)\n    )\n    connection4 = Connection(\n        source=node_b_id, sink=node_a_id, edge=create_socket_connection(4)\n    )\n\n    topology = Topology()\n    topology.add_node(node_a_id)\n    topology.add_node(node_b_id)\n    topology.add_node(node_c_id)\n    topology.add_connection(connection1)\n    topology.add_connection(connection2)\n    topology.add_connection(connection3)\n    topology.add_connection(connection4)\n\n    node_a_mem = 
create_node_memory(available_memory[0] * 1024)\n    node_b_mem = create_node_memory(available_memory[1] * 1024)\n    node_c_mem = create_node_memory(available_memory[2] * 1024)\n    node_memory = {\n        node_a_id: node_a_mem,\n        node_b_id: node_b_mem,\n        node_c_id: node_c_mem,\n    }\n\n    model_card = ModelCard(\n        model_id=ModelId(\"test-model\"),\n        n_layers=total_layers,\n        storage_size=Memory.from_kb(1000),\n        hidden_size=1000,\n        supports_tensor=True,\n        tasks=[ModelTask.TextGeneration],\n    )\n\n    cycles = topology.get_cycles()\n\n    # pick the 3-node cycle deterministically (cycle ordering can vary)\n    selected_cycle = next(cycle for cycle in cycles if len(cycle) == 3)\n\n    # act\n    shard_assignments = get_shard_assignments(\n        model_card, selected_cycle, Sharding.Pipeline, node_memory=node_memory\n    )\n\n    # assert\n    runner_id_a = shard_assignments.node_to_runner[node_a_id]\n    runner_id_b = shard_assignments.node_to_runner[node_b_id]\n    runner_id_c = shard_assignments.node_to_runner[node_c_id]\n\n    assert (\n        shard_assignments.runner_to_shard[runner_id_a].end_layer\n        - shard_assignments.runner_to_shard[runner_id_a].start_layer\n        == expected_layers[0]\n    )\n    assert (\n        shard_assignments.runner_to_shard[runner_id_b].end_layer\n        - shard_assignments.runner_to_shard[runner_id_b].start_layer\n        == expected_layers[1]\n    )\n    assert (\n        shard_assignments.runner_to_shard[runner_id_c].end_layer\n        - shard_assignments.runner_to_shard[runner_id_c].start_layer\n        == expected_layers[2]\n    )\n\n\ndef test_get_mlx_jaccl_coordinators():\n    # arrange\n    node_a_id = NodeId()\n    node_b_id = NodeId()\n    node_c_id = NodeId()\n\n    # fully connected (directed) between the 3 nodes\n    conn_a_b = Connection(\n        source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)\n    )\n    conn_b_a = Connection(\n        source=node_b_id, sink=node_a_id, edge=create_socket_connection(2)\n    )\n    conn_b_c = Connection(\n        source=node_b_id, sink=node_c_id, edge=create_socket_connection(3)\n    )\n    conn_c_b = Connection(\n        source=node_c_id, sink=node_b_id, edge=create_socket_connection(4)\n    )\n    conn_c_a = Connection(\n        source=node_c_id, sink=node_a_id, edge=create_socket_connection(5)\n    )\n    conn_a_c = Connection(\n        source=node_a_id, sink=node_c_id, edge=create_socket_connection(6)\n    )\n\n    network_a = NodeNetworkInfo(\n        interfaces=[\n            NetworkInterfaceInfo(name=\"en0\", ip_address=\"169.254.0.5\"),\n            NetworkInterfaceInfo(name=\"en0\", ip_address=\"169.254.0.2\"),\n        ]\n    )\n    network_b = NodeNetworkInfo(\n        interfaces=[\n            NetworkInterfaceInfo(name=\"en0\", ip_address=\"169.254.0.1\"),\n            NetworkInterfaceInfo(name=\"en0\", ip_address=\"169.254.0.4\"),\n        ]\n    )\n    network_c = NodeNetworkInfo(\n        interfaces=[\n            NetworkInterfaceInfo(name=\"en0\", ip_address=\"169.254.0.3\"),\n            NetworkInterfaceInfo(name=\"en0\", ip_address=\"169.254.0.6\"),\n        ]\n    )\n    node_network = {\n        node_a_id: network_a,\n        node_b_id: network_b,\n        node_c_id: network_c,\n    }\n\n    topology = Topology()\n    topology.add_node(node_a_id)\n    topology.add_node(node_b_id)\n    topology.add_node(node_c_id)\n\n    topology.add_connection(conn_a_b)\n    topology.add_connection(conn_b_a)\n    
topology.add_connection(conn_b_c)\n    topology.add_connection(conn_c_b)\n    topology.add_connection(conn_c_a)\n    topology.add_connection(conn_a_c)\n\n    # act\n    coordinators = get_mlx_jaccl_coordinators(\n        node_a_id,\n        coordinator_port=5000,\n        cycle_digraph=topology,\n        node_network=node_network,\n    )\n\n    # assert\n    assert len(coordinators) == 3\n    assert node_a_id in coordinators\n    assert node_b_id in coordinators\n    assert node_c_id in coordinators\n\n    # All coordinators should have IP:PORT format\n    for node_id, coordinator in coordinators.items():\n        assert \":\" in coordinator, (\n            f\"Coordinator for {node_id} should have ':' separator\"\n        )\n\n    # Verify port is correct\n    for node_id, coordinator in coordinators.items():\n        assert coordinator.endswith(\":5000\"), (\n            f\"Coordinator for {node_id} should use port 5000\"\n        )\n\n    # Rank 0 (node_a) treats this as the listen socket so should listen on all IPs\n    assert coordinators[node_a_id].startswith(\"0.0.0.0:\"), (\n        \"Rank 0 node should use 0.0.0.0 as coordinator listen address\"\n    )\n\n    # Non-rank-0 nodes should use the specific IP from their connection to rank 0\n    # node_b uses the IP from conn_b_a (node_b -> node_a)\n    assert isinstance(conn_b_a.edge, SocketConnection)\n    assert (\n        coordinators[node_b_id] == f\"{conn_b_a.edge.sink_multiaddr.ip_address}:5000\"\n    ), \"node_b should use the IP from conn_b_a\"\n\n    # node_c uses the IP from conn_c_a (node_c -> node_a)\n    assert isinstance(conn_c_a.edge, SocketConnection)\n    assert coordinators[node_c_id] == (\n        f\"{conn_c_a.edge.sink_multiaddr.ip_address}:5000\"\n    ), \"node_c should use the IP from conn_c_a\"\n\n\nclass TestAllocateLayersProportionally:\n    def test_empty_node_list_raises(self):\n        with pytest.raises(ValueError, match=\"empty node list\"):\n            allocate_layers_proportionally(total_layers=10, memory_fractions=[])\n\n    def test_zero_layers_raises(self):\n        with pytest.raises(ValueError, match=\"need at least 1 layer per node\"):\n            allocate_layers_proportionally(total_layers=0, memory_fractions=[0.5, 0.5])\n\n    def test_negative_layers_raises(self):\n        with pytest.raises(ValueError, match=\"need at least 1 layer per node\"):\n            allocate_layers_proportionally(total_layers=-1, memory_fractions=[0.5, 0.5])\n\n    def test_fewer_layers_than_nodes_raises(self):\n        with pytest.raises(ValueError, match=\"need at least 1 layer per node\"):\n            allocate_layers_proportionally(\n                total_layers=2, memory_fractions=[0.33, 0.33, 0.34]\n            )\n\n    def test_equal_distribution(self):\n        result = allocate_layers_proportionally(\n            total_layers=12, memory_fractions=[0.25, 0.25, 0.25, 0.25]\n        )\n        assert result == [3, 3, 3, 3]\n        assert sum(result) == 12\n\n    def test_proportional_distribution(self):\n        result = allocate_layers_proportionally(\n            total_layers=12, memory_fractions=[0.25, 0.25, 0.50]\n        )\n        assert result == [3, 3, 6]\n        assert sum(result) == 12\n\n    def test_extreme_imbalance_ensures_minimum(self):\n        result = allocate_layers_proportionally(\n            total_layers=20, memory_fractions=[0.975, 0.0125, 0.0125]\n        )\n        assert all(layers >= 1 for layers in result)\n        assert sum(result) == 20\n        # Small nodes get minimum 1 
layer\n        assert result == [18, 1, 1]\n\n    def test_single_node_gets_all_layers(self):\n        result = allocate_layers_proportionally(total_layers=10, memory_fractions=[1.0])\n        assert result == [10]\n\n    def test_minimum_viable_allocation(self):\n        result = allocate_layers_proportionally(\n            total_layers=3, memory_fractions=[0.33, 0.33, 0.34]\n        )\n        assert result == [1, 1, 1]\n        assert sum(result) == 3\n\n\ndef test_get_shard_assignments_insufficient_memory_raises():\n    \"\"\"Test that ValueError is raised when a node has insufficient memory for its layers.\"\"\"\n    node_a_id = NodeId()\n    node_b_id = NodeId()\n    node_c_id = NodeId()\n    topology = Topology()\n\n    # Node C has only 10 KB but would need 50 KB for 1 layer (1000 KB / 20 layers)\n    node_a_mem = create_node_memory(900 * 1024)\n    node_b_mem = create_node_memory(50 * 1024)\n    node_c_mem = create_node_memory(10 * 1024)  # Insufficient memory\n\n    topology.add_node(node_a_id)\n    topology.add_node(node_b_id)\n    topology.add_node(node_c_id)\n\n    conn_a_b = Connection(\n        source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)\n    )\n    conn_b_c = Connection(\n        source=node_b_id, sink=node_c_id, edge=create_socket_connection(2)\n    )\n    conn_c_a = Connection(\n        source=node_c_id, sink=node_a_id, edge=create_socket_connection(3)\n    )\n    conn_b_a = Connection(\n        source=node_b_id, sink=node_a_id, edge=create_socket_connection(3)\n    )\n    topology.add_connection(conn_a_b)\n    topology.add_connection(conn_b_c)\n    topology.add_connection(conn_c_a)\n    topology.add_connection(conn_b_a)\n\n    node_memory = {\n        node_a_id: node_a_mem,\n        node_b_id: node_b_mem,\n        node_c_id: node_c_mem,\n    }\n\n    model_card = ModelCard(\n        model_id=ModelId(\"test-model\"),\n        n_layers=20,\n        storage_size=Memory.from_kb(1000),\n        hidden_size=1000,\n        supports_tensor=True,\n        tasks=[ModelTask.TextGeneration],\n    )\n    cycles = topology.get_cycles()\n    selected_cycle = cycles[0]\n\n    with pytest.raises(ValueError, match=\"insufficient memory\"):\n        get_shard_assignments(\n            model_card, selected_cycle, Sharding.Pipeline, node_memory\n        )\n\n\nclass TestCfgParallelPlacement:\n    def _create_ring_topology(self, node_ids: list[NodeId]) -> Topology:\n        topology = Topology()\n        for node_id in node_ids:\n            topology.add_node(node_id)\n\n        for i, node_id in enumerate(node_ids):\n            next_node = node_ids[(i + 1) % len(node_ids)]\n            conn = Connection(\n                source=node_id,\n                sink=next_node,\n                edge=create_socket_connection(i + 1),\n            )\n            topology.add_connection(conn)\n\n        return topology\n\n    def test_two_nodes_cfg_model_uses_cfg_parallel(self):\n        \"\"\"Two nodes with CFG model should use CFG parallel (no pipeline).\"\"\"\n        node_a = NodeId()\n        node_b = NodeId()\n\n        topology = self._create_ring_topology([node_a, node_b])\n        cycles = [c for c in topology.get_cycles() if len(c) == 2]\n        cycle = cycles[0]\n\n        node_memory = {\n            node_a: create_node_memory(1000 * 1024),\n            node_b: create_node_memory(1000 * 1024),\n        }\n\n        model_card = ModelCard(\n            model_id=ModelId(\"qwen-image-test\"),\n            n_layers=60,\n            storage_size=Memory.from_kb(1000),\n     
       hidden_size=1,\n            supports_tensor=False,\n            uses_cfg=True,\n            tasks=[ModelTask.TextToImage],\n        )\n\n        assignments = get_shard_assignments_for_pipeline_parallel(\n            model_card, cycle, node_memory\n        )\n\n        shards = list(assignments.runner_to_shard.values())\n        assert len(shards) == 2\n\n        # CFG models should get CfgShardMetadata\n        for shard in shards:\n            assert isinstance(shard, CfgShardMetadata)\n            # Both nodes should have all layers (no pipeline split)\n            assert shard.start_layer == 0\n            assert shard.end_layer == 60\n            assert shard.cfg_world_size == 2\n            # Each node is the only stage in its pipeline group\n            assert shard.pipeline_world_size == 1\n            assert shard.pipeline_rank == 0\n\n        cfg_ranks = sorted(\n            s.cfg_rank for s in shards if isinstance(s, CfgShardMetadata)\n        )\n        assert cfg_ranks == [0, 1]\n\n    def test_four_nodes_cfg_model_uses_hybrid(self):\n        \"\"\"Four nodes with CFG model should use 2 CFG groups x 2 pipeline stages.\"\"\"\n        nodes = [NodeId() for _ in range(4)]\n\n        topology = self._create_ring_topology(nodes)\n        cycles = [c for c in topology.get_cycles() if len(c) == 4]\n        cycle = cycles[0]\n\n        node_memory = {n: create_node_memory(1000 * 1024) for n in nodes}\n\n        model_card = ModelCard(\n            model_id=ModelId(\"qwen-image-test\"),\n            n_layers=60,\n            storage_size=Memory.from_kb(1000),\n            hidden_size=1,\n            supports_tensor=False,\n            uses_cfg=True,\n            tasks=[ModelTask.TextToImage],\n        )\n\n        assignments = get_shard_assignments_for_pipeline_parallel(\n            model_card, cycle, node_memory\n        )\n\n        shards = list(assignments.runner_to_shard.values())\n        assert len(shards) == 4\n\n        # CFG models should get CfgShardMetadata\n        for shard in shards:\n            assert isinstance(shard, CfgShardMetadata)\n            assert shard.cfg_world_size == 2\n            assert shard.pipeline_world_size == 2\n            assert shard.pipeline_rank in [0, 1]\n\n        # Check we have 2 nodes in each CFG group\n        cfg_0_shards = [\n            s for s in shards if isinstance(s, CfgShardMetadata) and s.cfg_rank == 0\n        ]\n        cfg_1_shards = [\n            s for s in shards if isinstance(s, CfgShardMetadata) and s.cfg_rank == 1\n        ]\n        assert len(cfg_0_shards) == 2\n        assert len(cfg_1_shards) == 2\n\n        # Both CFG groups should have the same layer assignments\n        cfg_0_layers = [(s.start_layer, s.end_layer) for s in cfg_0_shards]\n        cfg_1_layers = [(s.start_layer, s.end_layer) for s in cfg_1_shards]\n        assert sorted(cfg_0_layers) == sorted(cfg_1_layers)\n\n    def test_three_nodes_cfg_model_uses_sequential_cfg(self):\n        \"\"\"Three nodes (odd) with CFG model should use sequential CFG (PipelineShardMetadata).\"\"\"\n        nodes = [NodeId() for _ in range(3)]\n\n        topology = self._create_ring_topology(nodes)\n        cycles = [c for c in topology.get_cycles() if len(c) == 3]\n        cycle = cycles[0]\n\n        node_memory = {n: create_node_memory(1000 * 1024) for n in nodes}\n\n        model_card = ModelCard(\n            model_id=ModelId(\"qwen-image-test\"),\n            n_layers=60,\n            storage_size=Memory.from_kb(1000),\n            hidden_size=1,\n           
 supports_tensor=False,\n            uses_cfg=True,\n            tasks=[ModelTask.TextToImage],\n        )\n\n        assignments = get_shard_assignments_for_pipeline_parallel(\n            model_card, cycle, node_memory\n        )\n\n        shards = list(assignments.runner_to_shard.values())\n        assert len(shards) == 3\n\n        # Odd node count with CFG model falls back to PipelineShardMetadata (sequential CFG)\n        for shard in shards:\n            assert isinstance(shard, PipelineShardMetadata)\n\n    def test_two_nodes_non_cfg_model_uses_pipeline(self):\n        \"\"\"Two nodes with non-CFG model should use pure pipeline (PipelineShardMetadata).\"\"\"\n        node_a = NodeId()\n        node_b = NodeId()\n\n        topology = self._create_ring_topology([node_a, node_b])\n        cycles = [c for c in topology.get_cycles() if len(c) == 2]\n        cycle = cycles[0]\n\n        node_memory = {\n            node_a: create_node_memory(1000 * 1024),\n            node_b: create_node_memory(1000 * 1024),\n        }\n\n        model_card = ModelCard(\n            model_id=ModelId(\"flux-test\"),\n            n_layers=57,\n            storage_size=Memory.from_kb(1000),\n            hidden_size=1,\n            supports_tensor=False,\n            uses_cfg=False,  # Non-CFG model\n            tasks=[ModelTask.TextToImage],\n        )\n\n        assignments = get_shard_assignments_for_pipeline_parallel(\n            model_card, cycle, node_memory\n        )\n\n        shards = list(assignments.runner_to_shard.values())\n        assert len(shards) == 2\n\n        # Non-CFG models should get PipelineShardMetadata\n        for shard in shards:\n            assert isinstance(shard, PipelineShardMetadata)\n\n        # Should have actual layer sharding (pipeline)\n        layer_ranges = sorted(\n            (s.start_layer, s.end_layer)\n            for s in shards\n            if isinstance(s, PipelineShardMetadata)\n        )\n        # First shard starts at 0, last shard ends at 57\n        assert layer_ranges[0][0] == 0\n        assert layer_ranges[-1][1] == 57\n"
  },
  {
    "path": "src/exo/master/tests/test_topology.py",
    "content": "import pytest\n\nfrom exo.shared.topology import Topology\nfrom exo.shared.types.common import NodeId\nfrom exo.shared.types.multiaddr import Multiaddr\nfrom exo.shared.types.topology import Connection, SocketConnection\n\n\n@pytest.fixture\ndef topology() -> Topology:\n    return Topology()\n\n\n@pytest.fixture\ndef socket_connection() -> SocketConnection:\n    return SocketConnection(\n        sink_multiaddr=Multiaddr(address=\"/ip4/127.0.0.1/tcp/1235\"),\n    )\n\n\ndef test_add_node(topology: Topology):\n    # arrange\n    node_id = NodeId()\n\n    # act\n    topology.add_node(node_id)\n\n    # assert\n    assert topology.node_is_leaf(node_id)\n\n\ndef test_add_connection(topology: Topology, socket_connection: SocketConnection):\n    # arrange\n    node_a = NodeId()\n    node_b = NodeId()\n    connection = Connection(source=node_a, sink=node_b, edge=socket_connection)\n\n    topology.add_node(node_a)\n    topology.add_node(node_b)\n    topology.add_connection(connection)\n\n    # act\n    data = list(topology.list_connections())\n\n    # assert\n    assert data == [connection]\n\n    assert topology.node_is_leaf(node_a)\n    assert topology.node_is_leaf(node_b)\n\n\ndef test_remove_connection_still_connected(\n    topology: Topology, socket_connection: SocketConnection\n):\n    # arrange\n    node_a = NodeId()\n    node_b = NodeId()\n    conn = Connection(source=node_a, sink=node_b, edge=socket_connection)\n\n    topology.add_node(node_a)\n    topology.add_node(node_b)\n    topology.add_connection(conn)\n\n    # act\n    topology.remove_connection(conn)\n\n    # assert\n    assert list(topology.get_all_connections_between(node_a, node_b)) == []\n\n\ndef test_remove_node_still_connected(\n    topology: Topology, socket_connection: SocketConnection\n):\n    # arrange\n    node_a = NodeId()\n    node_b = NodeId()\n    conn = Connection(source=node_a, sink=node_b, edge=socket_connection)\n\n    topology.add_node(node_a)\n    topology.add_node(node_b)\n    topology.add_connection(conn)\n    assert list(topology.out_edges(node_a)) == [conn]\n\n    # act\n    topology.remove_node(node_b)\n\n    # assert\n    assert list(topology.out_edges(node_a)) == []\n\n\ndef test_list_nodes(topology: Topology, socket_connection: SocketConnection):\n    # arrange\n    node_a = NodeId()\n    node_b = NodeId()\n    conn = Connection(source=node_a, sink=node_b, edge=socket_connection)\n\n    topology.add_node(node_a)\n    topology.add_node(node_b)\n    topology.add_connection(conn)\n    assert list(topology.out_edges(node_a)) == [conn]\n\n    # act\n    nodes = list(topology.list_nodes())\n\n    # assert\n    assert len(nodes) == 2\n    assert all(isinstance(node, NodeId) for node in nodes)\n    assert set(node for node in nodes) == set([node_a, node_b])\n"
  },
  {
    "path": "src/exo/routing/__init__.py",
    "content": ""
  },
  {
    "path": "src/exo/routing/connection_message.py",
    "content": "from exo_pyo3_bindings import PyFromSwarm\n\nfrom exo.shared.types.common import NodeId\nfrom exo.utils.pydantic_ext import CamelCaseModel\n\n\"\"\"Serialisable types for Connection Updates/Messages\"\"\"\n\n\nclass ConnectionMessage(CamelCaseModel):\n    node_id: NodeId\n    connected: bool\n\n    @classmethod\n    def from_update(cls, update: PyFromSwarm.Connection) -> \"ConnectionMessage\":\n        return cls(node_id=NodeId(update.peer_id), connected=update.connected)\n"
  },
  {
    "path": "src/exo/routing/event_router.py",
    "content": "from dataclasses import dataclass, field\nfrom random import random\n\nimport anyio\nfrom anyio import BrokenResourceError, ClosedResourceError\nfrom anyio.abc import CancelScope\nfrom loguru import logger\n\nfrom exo.shared.types.commands import ForwarderCommand, RequestEventLog\nfrom exo.shared.types.common import SessionId, SystemId\nfrom exo.shared.types.events import (\n    Event,\n    EventId,\n    GlobalForwarderEvent,\n    IndexedEvent,\n    LocalForwarderEvent,\n)\nfrom exo.utils.channels import Receiver, Sender, channel\nfrom exo.utils.event_buffer import OrderedBuffer\nfrom exo.utils.task_group import TaskGroup\n\n\n@dataclass\nclass EventRouter:\n    session_id: SessionId\n    command_sender: Sender[ForwarderCommand]\n    external_inbound: Receiver[GlobalForwarderEvent]\n    external_outbound: Sender[LocalForwarderEvent]\n    _system_id: SystemId = field(init=False, default_factory=SystemId)\n    internal_outbound: list[Sender[IndexedEvent]] = field(\n        init=False, default_factory=list\n    )\n    event_buffer: OrderedBuffer[Event] = field(\n        init=False, default_factory=OrderedBuffer\n    )\n    out_for_delivery: dict[EventId, tuple[float, LocalForwarderEvent]] = field(\n        init=False, default_factory=dict\n    )\n    _tg: TaskGroup = field(init=False, default_factory=TaskGroup)\n\n    _nack_cancel_scope: CancelScope | None = field(init=False, default=None)\n    _nack_attempts: int = field(init=False, default=0)\n    _nack_base_seconds: float = field(init=False, default=0.5)\n    _nack_cap_seconds: float = field(init=False, default=10.0)\n\n    async def run(self):\n        try:\n            async with self._tg as tg:\n                tg.start_soon(self._run_ext_in)\n                tg.start_soon(self._simple_retry)\n        finally:\n            self.external_outbound.close()\n            for send in self.internal_outbound:\n                send.close()\n\n    # can make this better in future\n    async def _simple_retry(self):\n        while True:\n            await anyio.sleep(1 + random())\n            # list here is a shallow clone for shared mutation\n            for e_id, (time, event) in list(self.out_for_delivery.items()):\n                if anyio.current_time() > time + 5:\n                    self.out_for_delivery[e_id] = (anyio.current_time(), event)\n                    await self.external_outbound.send(event)\n\n    def sender(self) -> Sender[Event]:\n        send, recv = channel[Event]()\n        if self._tg.is_running():\n            self._tg.start_soon(self._ingest, SystemId(), recv)\n        else:\n            self._tg.queue(self._ingest, SystemId(), recv)\n        return send\n\n    def receiver(self) -> Receiver[IndexedEvent]:\n        send, recv = channel[IndexedEvent]()\n        self.internal_outbound.append(send)\n        return recv\n\n    def shutdown(self) -> None:\n        self._tg.cancel_tasks()\n\n    async def _ingest(self, system_id: SystemId, recv: Receiver[Event]):\n        idx = 0\n        with recv as events:\n            async for event in events:\n                f_ev = LocalForwarderEvent(\n                    origin_idx=idx,\n                    origin=system_id,\n                    session=self.session_id,\n                    event=event,\n                )\n                idx += 1\n                await self.external_outbound.send(f_ev)\n                self.out_for_delivery[event.event_id] = (anyio.current_time(), f_ev)\n\n    async def _run_ext_in(self):\n        buf = OrderedBuffer[Event]()\n      
  with self.external_inbound as events:\n            async for event in events:\n                if event.session != self.session_id:\n                    continue\n                if event.origin != self.session_id.master_node_id:\n                    continue\n\n                buf.ingest(event.origin_idx, event.event)\n                event_id = event.event.event_id\n                if event_id in self.out_for_delivery:\n                    self.out_for_delivery.pop(event_id)\n\n                drained = buf.drain_indexed()\n                if drained:\n                    self._nack_attempts = 0\n                    if self._nack_cancel_scope:\n                        self._nack_cancel_scope.cancel()\n\n                if not drained and (\n                    self._nack_cancel_scope is None\n                    or self._nack_cancel_scope.cancel_called\n                ):\n                    # Request the next index.\n                    self._tg.start_soon(self._nack_request, buf.next_idx_to_release)\n                    continue\n\n                for idx, event in drained:\n                    to_clear = set[int]()\n                    for i, sender in enumerate(self.internal_outbound):\n                        try:\n                            await sender.send(IndexedEvent(idx=idx, event=event))\n                        except (ClosedResourceError, BrokenResourceError):\n                            to_clear.add(i)\n                    for i in sorted(to_clear, reverse=True):\n                        self.internal_outbound.pop(i)\n\n    async def _nack_request(self, since_idx: int) -> None:\n        # We request all events after (and including) the missing index.\n        # This function is started whenever we receive an event that is out of sequence.\n        # It is cancelled as soon as we receive an event that is in sequence.\n\n        if since_idx < 0:\n            logger.warning(f\"Negative value encountered for nack request {since_idx=}\")\n            since_idx = 0\n\n        with CancelScope() as scope:\n            self._nack_cancel_scope = scope\n            delay: float = self._nack_base_seconds * (2.0**self._nack_attempts)\n            delay = min(self._nack_cap_seconds, delay)\n            self._nack_attempts += 1\n            try:\n                await anyio.sleep(delay)\n                logger.info(\n                    f\"Nack attempt {self._nack_attempts}: Requesting Event Log from {since_idx}\"\n                )\n                await self.command_sender.send(\n                    ForwarderCommand(\n                        origin=self._system_id,\n                        command=RequestEventLog(since_idx=since_idx),\n                    )\n                )\n            finally:\n                if self._nack_cancel_scope is scope:\n                    self._nack_cancel_scope = None\n"
  },
  {
    "path": "src/exo/routing/router.py",
    "content": "from copy import copy\nfrom itertools import count\nfrom math import inf\nfrom os import PathLike\nfrom pathlib import Path\nfrom typing import cast\n\nfrom anyio import (\n    BrokenResourceError,\n    ClosedResourceError,\n    move_on_after,\n    sleep_forever,\n)\nfrom exo_pyo3_bindings import (\n    AllQueuesFullError,\n    Keypair,\n    MessageTooLargeError,\n    NetworkingHandle,\n    NoPeersSubscribedToTopicError,\n    PyFromSwarm,\n)\nfrom filelock import FileLock\nfrom loguru import logger\n\nfrom exo.shared.constants import EXO_NODE_ID_KEYPAIR\nfrom exo.utils.channels import Receiver, Sender, channel\nfrom exo.utils.pydantic_ext import CamelCaseModel\nfrom exo.utils.task_group import TaskGroup\n\nfrom .connection_message import ConnectionMessage\nfrom .topics import CONNECTION_MESSAGES, PublishPolicy, TypedTopic\n\n\n# A significant current limitation of the TopicRouter is that it is not capable\n# of preventing feedback, as it does not ask for a system id so cannot tell\n# which message is coming/going to which system.\n# This is currently only relevant for Election\nclass TopicRouter[T: CamelCaseModel]:\n    def __init__(\n        self,\n        topic: TypedTopic[T],\n        networking_sender: Sender[tuple[str, bytes]],\n        max_buffer_size: float = inf,\n    ):\n        self.topic: TypedTopic[T] = topic\n        self.senders: set[Sender[T]] = set()\n        send, recv = channel[T]()\n        self.receiver: Receiver[T] = recv\n        self._sender: Sender[T] = send\n        self.networking_sender: Sender[tuple[str, bytes]] = networking_sender\n\n    async def run(self):\n        logger.debug(f\"Topic Router {self.topic} ready to send\")\n        with self.receiver as items:\n            async for item in items:\n                # Check if we should send to network\n                if (\n                    len(self.senders) == 0\n                    and self.topic.publish_policy is PublishPolicy.Minimal\n                ):\n                    await self._send_out(item)\n                    continue\n                if self.topic.publish_policy is PublishPolicy.Always:\n                    await self._send_out(item)\n                # Then publish to all senders\n                await self.publish(item)\n\n    async def shutdown(self):\n        logger.debug(f\"Shutting down Topic Router {self.topic}\")\n        # Close all the things!\n        for sender in self.senders:\n            sender.close()\n        self._sender.close()\n        self.receiver.close()\n\n    async def publish(self, item: T):\n        \"\"\"\n        Publish item T on this topic to all senders.\n        NB: this sends to ALL receivers, potentially including receivers held by the object doing the sending.\n        You should handle your own output if you hold a sender + receiver pair.\n        \"\"\"\n        to_clear: set[Sender[T]] = set()\n        for sender in copy(self.senders):\n            try:\n                await sender.send(item)\n            except (ClosedResourceError, BrokenResourceError):\n                to_clear.add(sender)\n        self.senders -= to_clear\n\n    async def publish_bytes(self, data: bytes):\n        await self.publish(self.topic.deserialize(data))\n\n    def new_sender(self) -> Sender[T]:\n        return self._sender.clone()\n\n    async def _send_out(self, item: T):\n        logger.trace(f\"TopicRouter {self.topic.topic} sending {item}\")\n        await self.networking_sender.send(\n            (str(self.topic.topic), self.topic.serialize(item))\n    
    )\n\n\nclass Router:\n    @classmethod\n    def create(cls, identity: Keypair) -> \"Router\":\n        return cls(handle=NetworkingHandle(identity))\n\n    def __init__(self, handle: NetworkingHandle):\n        self.topic_routers: dict[str, TopicRouter[CamelCaseModel]] = {}\n        send, recv = channel[tuple[str, bytes]]()\n        self.networking_receiver: Receiver[tuple[str, bytes]] = recv\n        self._net: NetworkingHandle = handle\n        self._tmp_networking_sender: Sender[tuple[str, bytes]] | None = send\n        self._id_count = count()\n        self._tg: TaskGroup = TaskGroup()\n\n    async def register_topic[T: CamelCaseModel](self, topic: TypedTopic[T]):\n        send = self._tmp_networking_sender\n        if send:\n            self._tmp_networking_sender = None\n        else:\n            send = self.networking_receiver.clone_sender()\n        router = TopicRouter[T](topic, send)\n        self.topic_routers[topic.topic] = cast(TopicRouter[CamelCaseModel], router)\n        if self._tg.is_running():\n            await self._networking_subscribe(topic.topic)\n\n    def sender[T: CamelCaseModel](self, topic: TypedTopic[T]) -> Sender[T]:\n        router = self.topic_routers.get(topic.topic, None)\n        # There's gotta be a way to do this without THIS many asserts\n        assert router is not None\n        assert router.topic == topic\n        sender = cast(TopicRouter[T], router).new_sender()\n        return sender\n\n    def receiver[T: CamelCaseModel](self, topic: TypedTopic[T]) -> Receiver[T]:\n        router = self.topic_routers.get(topic.topic, None)\n        # There's gotta be a way to do this without THIS many asserts\n\n        assert router is not None\n        assert router.topic == topic\n        assert router.topic.model_type == topic.model_type\n\n        send, recv = channel[T]()\n        router.senders.add(cast(Sender[CamelCaseModel], send))\n\n        return recv\n\n    async def run(self):\n        logger.debug(\"Starting Router\")\n        try:\n            async with self._tg as tg:\n                for topic in self.topic_routers:\n                    router = self.topic_routers[topic]\n                    tg.start_soon(router.run)\n                tg.start_soon(self._networking_recv)\n                tg.start_soon(self._networking_publish)\n                # subscribe to pending topics\n                for topic in self.topic_routers:\n                    await self._networking_subscribe(topic)\n                # Router only shuts down if you cancel it.\n                await sleep_forever()\n        finally:\n            with move_on_after(1, shield=True):\n                for topic in self.topic_routers:\n                    await self._networking_unsubscribe(str(topic))\n\n    async def shutdown(self):\n        logger.debug(\"Shutting down Router\")\n        self._tg.cancel_tasks()\n\n    async def _networking_subscribe(self, topic: str):\n        await self._net.gossipsub_subscribe(topic)\n        logger.info(f\"Subscribed to {topic}\")\n\n    async def _networking_unsubscribe(self, topic: str):\n        await self._net.gossipsub_unsubscribe(topic)\n        logger.info(f\"Unsubscribed from {topic}\")\n\n    async def _networking_recv(self):\n        try:\n            while True:\n                from_swarm = await self._net.recv()\n                logger.debug(from_swarm)\n                match from_swarm:\n                    case PyFromSwarm.Message(origin, topic, data):\n                        logger.trace(\n                            
f\"Received message on {topic} from {origin} with payload {data}\"\n                        )\n                        if topic not in self.topic_routers:\n                            logger.warning(\n                                f\"Received message on unknown or inactive topic {topic}\"\n                            )\n                            continue\n                        router = self.topic_routers[topic]\n                        await router.publish_bytes(data)\n                    case PyFromSwarm.Connection():\n                        message = ConnectionMessage.from_update(from_swarm)\n                        logger.trace(\n                            f\"Received message on connection_messages with payload {message}\"\n                        )\n                        if CONNECTION_MESSAGES.topic in self.topic_routers:\n                            router = self.topic_routers[CONNECTION_MESSAGES.topic]\n                            assert router.topic.model_type == ConnectionMessage\n                            router = cast(TopicRouter[ConnectionMessage], router)\n                            await router.publish(message)\n                    case _:\n                        logger.critical(\n                            \"failed to exhaustively check FromSwarm messages - logic error\"\n                        )\n        except Exception as exception:\n            logger.opt(exception=exception).error(\n                \"Gossipsub receive loop terminated unexpectedly\"\n            )\n            raise\n\n    async def _networking_publish(self):\n        with self.networking_receiver as networked_items:\n            async for topic, data in networked_items:\n                try:\n                    logger.trace(f\"Sending message on {topic} with payload {data}\")\n                    if len(data) > 1024 * 1024:\n                        logger.warning(\n                            \"Sending overlarge payload, network performance may be temporarily degraded\"\n                        )\n                    await self._net.gossipsub_publish(topic, data)\n                except NoPeersSubscribedToTopicError:\n                    pass\n                except AllQueuesFullError:\n                    logger.warning(f\"All peer queues full, dropping message on {topic}\")\n                except MessageTooLargeError:\n                    logger.warning(\n                        f\"Message too large for gossipsub on {topic} ({len(data)} bytes), dropping\"\n                    )\n\n\ndef get_node_id_keypair(\n    path: str | bytes | PathLike[str] | PathLike[bytes] = EXO_NODE_ID_KEYPAIR,\n) -> Keypair:\n    \"\"\"\n    Obtains the :class:`Keypair` associated with this node-ID.\n    Obtain the :class:`PeerId` by from it.\n    \"\"\"\n    # TODO(evan): bring back node id persistence once we figure out how to deal with duplicates\n    return Keypair.generate()\n\n    def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:\n        return Path(str(path) + \".lock\")\n\n    # operate with cross-process lock to avoid race conditions\n    with FileLock(lock_path(path)):\n        with open(path, \"a+b\") as f:  # opens in append-mode => starts at EOF\n            # if non-zero EOF, then file exists => use to get node-ID\n            if f.tell() != 0:\n                f.seek(0)  # go to start & read protobuf-encoded bytes\n                protobuf_encoded = f.read()\n\n                try:  # if decoded successfully, save & return\n                    return 
Keypair.from_bytes(protobuf_encoded)\n                except ValueError as e:  # on decode error, assume corrupt file\n                    logger.warning(f\"Encountered error when trying to get keypair: {e}\")\n\n        # if no valid credentials, create new ones and persist\n        with open(path, \"w+b\") as f:\n            keypair = Keypair.generate()\n            f.write(keypair.to_bytes())\n            return keypair\n"
  },
  {
    "path": "src/exo/routing/tests/test_event_buffer.py",
    "content": "import pytest\n\nfrom exo.shared.types.events import Event, TestEvent\nfrom exo.utils.event_buffer import OrderedBuffer\n\n\ndef make_indexed_event(idx: int) -> tuple[int, Event]:\n    \"\"\"Factory function to create a unique ForwarderEvent for a given index.\"\"\"\n    return (idx, TestEvent())\n\n\n@pytest.fixture\ndef buffer() -> OrderedBuffer[Event]:\n    \"\"\"Provides a clean instance of OrderedBuffer[Event] for each test.\"\"\"\n    return OrderedBuffer[Event]()\n\n\n@pytest.mark.asyncio\nasync def test_initial_state(buffer: OrderedBuffer[Event]):\n    \"\"\"Tests that a new buffer is empty and starts at index 1.\"\"\"\n    assert buffer.next_idx_to_release == 0\n    assert not buffer.store\n    assert buffer.drain() == []\n\n\n@pytest.mark.asyncio\nasync def test_ingest_and_drain_sequential_events(buffer: OrderedBuffer[Event]):\n    \"\"\"Tests ingesting and draining a simple, ordered sequence of events.\"\"\"\n    events = [make_indexed_event(0), make_indexed_event(1), make_indexed_event(2)]\n    [buffer.ingest(*ev) for ev in events]\n\n    drained_events = buffer.drain_indexed()\n    assert drained_events == events\n    assert buffer.next_idx_to_release == 3\n    assert not buffer.store\n\n\n@pytest.mark.asyncio\nasync def test_ingest_out_of_order_events(buffer: OrderedBuffer[Event]):\n    \"\"\"Tests that out-of-order events are buffered and drained in the correct sequence.\"\"\"\n    event1 = make_indexed_event(0)\n    event2 = make_indexed_event(1)\n    event3 = make_indexed_event(2)\n\n    buffer.ingest(*event3)\n    buffer.ingest(*event1)\n    buffer.ingest(*event2)\n\n    drained_events = buffer.drain_indexed()\n    assert drained_events == [event1, event2, event3]\n    assert buffer.next_idx_to_release == 3\n\n\n@pytest.mark.asyncio\nasync def test_drain_with_gap_in_sequence(buffer: OrderedBuffer[Event]):\n    \"\"\"Tests that draining stops when there is a gap in the event indices.\"\"\"\n    event1 = make_indexed_event(0)\n    event3 = make_indexed_event(2)\n\n    buffer.ingest(*event1)\n    buffer.ingest(*event3)\n\n    drained_events = buffer.drain_indexed()\n    assert drained_events == [event1]\n    assert buffer.next_idx_to_release == 1\n\n    assert buffer.drain() == []\n    assert 2 in buffer.store\n\n\n@pytest.mark.asyncio\nasync def test_fill_gap_and_drain_remaining(buffer: OrderedBuffer[Event]):\n    \"\"\"Tests that once a gap is filled, the rest of the sequence is drained.\"\"\"\n    event0 = make_indexed_event(0)\n    event2 = make_indexed_event(2)\n    buffer.ingest(*event0)\n    buffer.ingest(*event2)\n\n    buffer.drain()\n    assert buffer.next_idx_to_release == 1\n\n    event1 = make_indexed_event(1)\n    buffer.ingest(*event1)\n\n    drained_events = buffer.drain_indexed()\n    assert [e[0] for e in drained_events] == [1, 2]\n    assert buffer.next_idx_to_release == 3\n\n\n@pytest.mark.asyncio\nasync def test_ingest_drops_duplicate_indices(buffer: OrderedBuffer[Event]):\n    \"\"\"Tests that if multiple events for the same index are ingested, the first one wins.\"\"\"\n    event2_first = make_indexed_event(1)\n    event2_second = (1, TestEvent())\n\n    buffer.ingest(*make_indexed_event(0))\n    buffer.ingest(*event2_first)\n\n    with pytest.raises(AssertionError):\n        buffer.ingest(*event2_second)  # This duplicate should be ignored\n\n    drained = buffer.drain_indexed()\n    assert len(drained) == 2\n\n    assert drained[1][1].event_id == event2_first[1].event_id\n    assert drained[1][1].event_id != 
event2_second[1].event_id\n\n\n@pytest.mark.asyncio\nasync def test_ingest_drops_stale_events(buffer: OrderedBuffer[Event]):\n    \"\"\"Tests that events with an index lower than next_idx_to_release are dropped.\"\"\"\n    buffer.ingest(*make_indexed_event(0))\n    buffer.ingest(*make_indexed_event(1))\n    buffer.drain()\n\n    assert buffer.next_idx_to_release == 2\n\n    stale_event1 = make_indexed_event(0)\n    stale_event2 = make_indexed_event(1)\n    buffer.ingest(*stale_event1)\n    buffer.ingest(*stale_event2)\n\n    assert not buffer.store\n    assert buffer.drain() == []\n\n\n@pytest.mark.asyncio\nasync def test_drain_and_ingest_with_new_sequence(buffer: OrderedBuffer[Event]):\n    \"\"\"Tests reusing the buffer after it has been fully drained.\"\"\"\n    buffer.ingest(*make_indexed_event(0))\n    buffer.ingest(*make_indexed_event(1))\n    buffer.drain()\n\n    assert buffer.next_idx_to_release == 2\n    assert not buffer.store\n\n    buffer.ingest(*make_indexed_event(4))\n    buffer.ingest(*make_indexed_event(2))\n\n    drained = buffer.drain_indexed()\n    assert [e[0] for e in drained] == [2]\n    assert buffer.next_idx_to_release == 3\n    assert 4 in buffer.store\n"
  },
  {
    "path": "src/exo/routing/topics.py",
    "content": "from dataclasses import dataclass\nfrom enum import Enum\n\nfrom exo.routing.connection_message import ConnectionMessage\nfrom exo.shared.election import ElectionMessage\nfrom exo.shared.types.commands import ForwarderCommand, ForwarderDownloadCommand\nfrom exo.shared.types.events import (\n    GlobalForwarderEvent,\n    LocalForwarderEvent,\n)\nfrom exo.utils.pydantic_ext import CamelCaseModel\n\n\nclass PublishPolicy(str, Enum):\n    Never = \"Never\"\n    \"\"\"Never publish to the network - this is a local message\"\"\"\n    Minimal = \"Minimal\"\n    \"\"\"Only publish when there is no local receiver for this type of message\"\"\"\n    Always = \"Always\"\n    \"\"\"Always publish to the network\"\"\"\n\n\n@dataclass  # (frozen=True)\nclass TypedTopic[T: CamelCaseModel]:\n    topic: str\n    publish_policy: PublishPolicy\n\n    model_type: type[\n        T\n    ]  # This can be worked around with evil type hacking, see https://stackoverflow.com/a/71720366 - I don't think it's necessary here.\n\n    @staticmethod\n    def serialize(t: T) -> bytes:\n        return t.model_dump_json().encode(\"utf-8\")\n\n    def deserialize(self, b: bytes) -> T:\n        return self.model_type.model_validate_json(b.decode(\"utf-8\"))\n\n\nGLOBAL_EVENTS = TypedTopic(\"global_events\", PublishPolicy.Always, GlobalForwarderEvent)\nLOCAL_EVENTS = TypedTopic(\"local_events\", PublishPolicy.Always, LocalForwarderEvent)\nCOMMANDS = TypedTopic(\"commands\", PublishPolicy.Always, ForwarderCommand)\nELECTION_MESSAGES = TypedTopic(\n    \"election_messages\", PublishPolicy.Always, ElectionMessage\n)\nCONNECTION_MESSAGES = TypedTopic(\n    \"connection_messages\", PublishPolicy.Never, ConnectionMessage\n)\nDOWNLOAD_COMMANDS = TypedTopic(\n    \"download_commands\", PublishPolicy.Always, ForwarderDownloadCommand\n)\n"
  },
  {
    "path": "src/exo/shared/__init__.py",
    "content": ""
  },
  {
    "path": "src/exo/shared/apply.py",
    "content": "import copy\nfrom collections.abc import Mapping, Sequence\nfrom datetime import datetime\n\nfrom loguru import logger\n\nfrom exo.shared.types.common import NodeId\nfrom exo.shared.types.events import (\n    ChunkGenerated,\n    Event,\n    IndexedEvent,\n    InputChunkReceived,\n    InstanceCreated,\n    InstanceDeleted,\n    NodeDownloadProgress,\n    NodeGatheredInfo,\n    NodeTimedOut,\n    RunnerStatusUpdated,\n    TaskAcknowledged,\n    TaskCreated,\n    TaskDeleted,\n    TaskFailed,\n    TaskStatusUpdated,\n    TestEvent,\n    TopologyEdgeCreated,\n    TopologyEdgeDeleted,\n    TracesCollected,\n    TracesMerged,\n)\nfrom exo.shared.types.profiling import (\n    NodeIdentity,\n    NodeNetworkInfo,\n    NodeRdmaCtlStatus,\n    NodeThunderboltInfo,\n    ThunderboltBridgeStatus,\n)\nfrom exo.shared.types.state import State\nfrom exo.shared.types.tasks import Task, TaskId, TaskStatus\nfrom exo.shared.types.topology import Connection, RDMAConnection\nfrom exo.shared.types.worker.downloads import DownloadProgress\nfrom exo.shared.types.worker.instances import Instance, InstanceId\nfrom exo.shared.types.worker.runners import RunnerId, RunnerShutdown, RunnerStatus\nfrom exo.utils.info_gatherer.info_gatherer import (\n    MacmonMetrics,\n    MacThunderboltConnections,\n    MacThunderboltIdentifiers,\n    MemoryUsage,\n    MiscData,\n    NodeConfig,\n    NodeDiskUsage,\n    NodeNetworkInterfaces,\n    RdmaCtlStatus,\n    StaticNodeInformation,\n    ThunderboltBridgeInfo,\n)\n\n\ndef event_apply(event: Event, state: State) -> State:\n    \"\"\"Apply an event to state.\"\"\"\n    match event:\n        case (\n            TestEvent()\n            | ChunkGenerated()\n            | TaskAcknowledged()\n            | InputChunkReceived()\n            | TracesCollected()\n            | TracesMerged()\n        ):  # Pass-through events that don't modify state\n            return state\n        case InstanceCreated():\n            return apply_instance_created(event, state)\n        case InstanceDeleted():\n            return apply_instance_deleted(event, state)\n        case NodeTimedOut():\n            return apply_node_timed_out(event, state)\n        case NodeDownloadProgress():\n            return apply_node_download_progress(event, state)\n        case NodeGatheredInfo():\n            return apply_node_gathered_info(event, state)\n        case RunnerStatusUpdated():\n            return apply_runner_status_updated(event, state)\n        case TaskCreated():\n            return apply_task_created(event, state)\n        case TaskDeleted():\n            return apply_task_deleted(event, state)\n        case TaskFailed():\n            return apply_task_failed(event, state)\n        case TaskStatusUpdated():\n            return apply_task_status_updated(event, state)\n        case TopologyEdgeCreated():\n            return apply_topology_edge_created(event, state)\n        case TopologyEdgeDeleted():\n            return apply_topology_edge_deleted(event, state)\n\n\ndef apply(state: State, event: IndexedEvent) -> State:\n    # Just to test that events are only applied in correct order\n    if state.last_event_applied_idx != event.idx - 1:\n        logger.warning(\n            f\"Expected event {state.last_event_applied_idx + 1} but received {event.idx}\"\n        )\n    assert state.last_event_applied_idx == event.idx - 1\n    new_state: State = event_apply(event.event, state)\n    return new_state.model_copy(update={\"last_event_applied_idx\": event.idx})\n\n\ndef 
apply_node_download_progress(event: NodeDownloadProgress, state: State) -> State:\n    \"\"\"\n    Update or add a node's download progress in state.\n    \"\"\"\n    dp = event.download_progress\n    node_id = dp.node_id\n\n    current = list(state.downloads.get(node_id, ()))\n\n    replaced = False\n    for i, existing_dp in enumerate(current):\n        # TODO(ciaran): deduplicate by model_id for now; switch back to\n        # shard_metadata once pipeline and tensor downloads differ.\n        if (\n            existing_dp.shard_metadata.model_card.model_id\n            == dp.shard_metadata.model_card.model_id\n        ):\n            current[i] = dp\n            replaced = True\n            break\n\n    if not replaced:\n        current.append(dp)\n\n    new_downloads: Mapping[NodeId, Sequence[DownloadProgress]] = {\n        **state.downloads,\n        node_id: current,\n    }\n    return state.model_copy(update={\"downloads\": new_downloads})\n\n\ndef apply_task_created(event: TaskCreated, state: State) -> State:\n    new_tasks: Mapping[TaskId, Task] = {**state.tasks, event.task_id: event.task}\n    return state.model_copy(update={\"tasks\": new_tasks})\n\n\ndef apply_task_deleted(event: TaskDeleted, state: State) -> State:\n    new_tasks: Mapping[TaskId, Task] = {\n        tid: task for tid, task in state.tasks.items() if tid != event.task_id\n    }\n    return state.model_copy(update={\"tasks\": new_tasks})\n\n\ndef apply_task_status_updated(event: TaskStatusUpdated, state: State) -> State:\n    if event.task_id not in state.tasks:\n        # Unknown task: ignore the update (arguably this should raise)\n        return state\n\n    update: dict[str, TaskStatus | None] = {\n        \"task_status\": event.task_status,\n    }\n    if event.task_status != TaskStatus.Failed:\n        update[\"error_type\"] = None\n        update[\"error_message\"] = None\n\n    updated_task = state.tasks[event.task_id].model_copy(update=update)\n    new_tasks: Mapping[TaskId, Task] = {**state.tasks, event.task_id: updated_task}\n    return state.model_copy(update={\"tasks\": new_tasks})\n\n\ndef apply_task_failed(event: TaskFailed, state: State) -> State:\n    if event.task_id not in state.tasks:\n        # Unknown task: ignore the update (arguably this should raise)\n        return state\n\n    updated_task = state.tasks[event.task_id].model_copy(\n        update={\"error_type\": event.error_type, \"error_message\": event.error_message}\n    )\n    new_tasks: Mapping[TaskId, Task] = {**state.tasks, event.task_id: updated_task}\n    return state.model_copy(update={\"tasks\": new_tasks})\n\n\ndef apply_instance_created(event: InstanceCreated, state: State) -> State:\n    instance = event.instance\n    new_instances: Mapping[InstanceId, Instance] = {\n        **state.instances,\n        instance.instance_id: instance,\n    }\n    return state.model_copy(update={\"instances\": new_instances})\n\n\ndef apply_instance_deleted(event: InstanceDeleted, state: State) -> State:\n    new_instances: Mapping[InstanceId, Instance] = {\n        iid: inst for iid, inst in state.instances.items() if iid != event.instance_id\n    }\n    return state.model_copy(update={\"instances\": new_instances})\n\n\ndef apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> State:\n    if isinstance(event.runner_status, RunnerShutdown):\n        new_runners: Mapping[RunnerId, RunnerStatus] = {\n            rid: rs for rid, rs in state.runners.items() if rid != event.runner_id\n        }\n        return state.model_copy(update={\"runners\": new_runners})\n    
new_runners = {\n        **state.runners,\n        event.runner_id: event.runner_status,\n    }\n    return state.model_copy(update={\"runners\": new_runners})\n\n\ndef apply_node_timed_out(event: NodeTimedOut, state: State) -> State:\n    topology = copy.deepcopy(state.topology)\n    topology.remove_node(event.node_id)\n    last_seen = {\n        key: value for key, value in state.last_seen.items() if key != event.node_id\n    }\n    downloads = {\n        key: value for key, value in state.downloads.items() if key != event.node_id\n    }\n    # Clean up all granular node mappings\n    node_memory = {\n        key: value for key, value in state.node_memory.items() if key != event.node_id\n    }\n    node_disk = {\n        key: value for key, value in state.node_disk.items() if key != event.node_id\n    }\n    node_system = {\n        key: value for key, value in state.node_system.items() if key != event.node_id\n    }\n    node_network = {\n        key: value for key, value in state.node_network.items() if key != event.node_id\n    }\n    node_thunderbolt = {\n        key: value\n        for key, value in state.node_thunderbolt.items()\n        if key != event.node_id\n    }\n    node_thunderbolt_bridge = {\n        key: value\n        for key, value in state.node_thunderbolt_bridge.items()\n        if key != event.node_id\n    }\n    node_rdma_ctl = {\n        key: value for key, value in state.node_rdma_ctl.items() if key != event.node_id\n    }\n    # Only recompute cycles if the leaving node had TB bridge enabled\n    leaving_node_status = state.node_thunderbolt_bridge.get(event.node_id)\n    leaving_node_had_tb_enabled = (\n        leaving_node_status is not None and leaving_node_status.enabled\n    )\n    thunderbolt_bridge_cycles = (\n        topology.get_thunderbolt_bridge_cycles(node_thunderbolt_bridge, node_network)\n        if leaving_node_had_tb_enabled\n        else [list(cycle) for cycle in state.thunderbolt_bridge_cycles]\n    )\n    return state.model_copy(\n        update={\n            \"downloads\": downloads,\n            \"topology\": topology,\n            \"last_seen\": last_seen,\n            \"node_memory\": node_memory,\n            \"node_disk\": node_disk,\n            \"node_system\": node_system,\n            \"node_network\": node_network,\n            \"node_thunderbolt\": node_thunderbolt,\n            \"node_thunderbolt_bridge\": node_thunderbolt_bridge,\n            \"node_rdma_ctl\": node_rdma_ctl,\n            \"thunderbolt_bridge_cycles\": thunderbolt_bridge_cycles,\n        }\n    )\n\n\ndef apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:\n    topology = copy.deepcopy(state.topology)\n    topology.add_node(event.node_id)\n    info = event.info\n\n    # Build update dict with only the mappings that change\n    update: dict[str, object] = {\n        \"last_seen\": {\n            **state.last_seen,\n            event.node_id: datetime.fromisoformat(event.when),\n        },\n        \"topology\": topology,\n    }\n\n    match info:\n        case MacmonMetrics():\n            update[\"node_system\"] = {\n                **state.node_system,\n                event.node_id: info.system_profile,\n            }\n            update[\"node_memory\"] = {**state.node_memory, event.node_id: info.memory}\n        case MemoryUsage():\n            update[\"node_memory\"] = {**state.node_memory, event.node_id: info}\n        case NodeDiskUsage():\n            update[\"node_disk\"] = {**state.node_disk, event.node_id: info.disk_usage}\n        
case NodeConfig():\n            pass\n        case MiscData():\n            current_identity = state.node_identities.get(event.node_id, NodeIdentity())\n            new_identity = current_identity.model_copy(\n                update={\"friendly_name\": info.friendly_name}\n            )\n            update[\"node_identities\"] = {\n                **state.node_identities,\n                event.node_id: new_identity,\n            }\n        case StaticNodeInformation():\n            current_identity = state.node_identities.get(event.node_id, NodeIdentity())\n            new_identity = current_identity.model_copy(\n                update={\n                    \"model_id\": info.model,\n                    \"chip_id\": info.chip,\n                    \"os_version\": info.os_version,\n                    \"os_build_version\": info.os_build_version,\n                }\n            )\n            update[\"node_identities\"] = {\n                **state.node_identities,\n                event.node_id: new_identity,\n            }\n        case NodeNetworkInterfaces():\n            update[\"node_network\"] = {\n                **state.node_network,\n                event.node_id: NodeNetworkInfo(interfaces=info.ifaces),\n            }\n        case MacThunderboltIdentifiers():\n            update[\"node_thunderbolt\"] = {\n                **state.node_thunderbolt,\n                event.node_id: NodeThunderboltInfo(interfaces=info.idents),\n            }\n        case MacThunderboltConnections():\n            conn_map = {\n                tb_ident.domain_uuid: (nid, tb_ident.rdma_interface)\n                for nid in state.node_thunderbolt\n                for tb_ident in state.node_thunderbolt[nid].interfaces\n            }\n            as_rdma_conns = [\n                Connection(\n                    source=event.node_id,\n                    sink=conn_map[tb_conn.sink_uuid][0],\n                    edge=RDMAConnection(\n                        source_rdma_iface=conn_map[tb_conn.source_uuid][1],\n                        sink_rdma_iface=conn_map[tb_conn.sink_uuid][1],\n                    ),\n                )\n                for tb_conn in info.conns\n                if tb_conn.source_uuid in conn_map\n                if tb_conn.sink_uuid in conn_map\n            ]\n            topology.replace_all_out_rdma_connections(event.node_id, as_rdma_conns)\n        case ThunderboltBridgeInfo():\n            new_tb_bridge: dict[NodeId, ThunderboltBridgeStatus] = {\n                **state.node_thunderbolt_bridge,\n                event.node_id: info.status,\n            }\n            update[\"node_thunderbolt_bridge\"] = new_tb_bridge\n            # Only recompute cycles if the enabled status changed\n            old_status = state.node_thunderbolt_bridge.get(event.node_id)\n            old_enabled = old_status.enabled if old_status else False\n            new_enabled = info.status.enabled\n            if old_enabled != new_enabled:\n                update[\"thunderbolt_bridge_cycles\"] = (\n                    topology.get_thunderbolt_bridge_cycles(\n                        new_tb_bridge, state.node_network\n                    )\n                )\n        case RdmaCtlStatus():\n            update[\"node_rdma_ctl\"] = {\n                **state.node_rdma_ctl,\n                event.node_id: NodeRdmaCtlStatus(enabled=info.enabled),\n            }\n\n    return state.model_copy(update=update)\n\n\ndef apply_topology_edge_created(event: TopologyEdgeCreated, state: State) -> State:\n    topology = 
copy.deepcopy(state.topology)\n    topology.add_connection(event.conn)\n    return state.model_copy(update={\"topology\": topology})\n\n\ndef apply_topology_edge_deleted(event: TopologyEdgeDeleted, state: State) -> State:\n    topology = copy.deepcopy(state.topology)\n    topology.remove_connection(event.conn)\n    # TODO: Clean up removing the reverse connection\n    return state.model_copy(update={\"topology\": topology})\n"
  },
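  {
    "path": "src/exo/shared/apply_example_sketch.py",
    "content": "\"\"\"Illustrative sketch added for exposition (not part of the original\ncodebase): drives the pure reducer in exo.shared.apply the way the event log\ndoes, with strictly increasing IndexedEvent indices. Assumes TestEvent is\nconstructible without arguments; everything else mirrors apply.py.\"\"\"\n\nfrom exo.shared.apply import apply\nfrom exo.shared.types.events import IndexedEvent, TestEvent\nfrom exo.shared.types.state import State\n\n\ndef replay_three_events() -> State:\n    state = State()\n    # apply() asserts idx == last_event_applied_idx + 1, so derive the first\n    # index from the fresh state instead of hard-coding it.\n    next_index = state.last_event_applied_idx + 1\n    for index in range(next_index, next_index + 3):\n        state = apply(state, IndexedEvent(idx=index, event=TestEvent()))\n    # TestEvent is a pass-through, so only the applied-index cursor advanced.\n    assert state.last_event_applied_idx == next_index + 2\n    return state\n"
  },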
  {
    "path": "src/exo/shared/constants.py",
    "content": "import os\nimport sys\nfrom pathlib import Path\n\nfrom exo.utils.dashboard_path import find_dashboard, find_resources\n\n_EXO_HOME_ENV = os.environ.get(\"EXO_HOME\", None)\n\n\ndef _get_xdg_dir(env_var: str, fallback: str) -> Path:\n    \"\"\"Get XDG directory, prioritising EXO_HOME environment variable if its set. On non-Linux platforms, default to ~/.exo.\"\"\"\n\n    if _EXO_HOME_ENV is not None:\n        return Path.home() / _EXO_HOME_ENV\n\n    if sys.platform != \"linux\":\n        return Path.home() / \".exo\"\n\n    xdg_value = os.environ.get(env_var, None)\n    if xdg_value is not None:\n        return Path(xdg_value) / \"exo\"\n    return Path.home() / fallback / \"exo\"\n\n\nEXO_CONFIG_HOME = _get_xdg_dir(\"XDG_CONFIG_HOME\", \".config\")\nEXO_DATA_HOME = _get_xdg_dir(\"XDG_DATA_HOME\", \".local/share\")\nEXO_CACHE_HOME = _get_xdg_dir(\"XDG_CACHE_HOME\", \".cache\")\n\n# Models directory (data)\n_EXO_MODELS_DIR_ENV = os.environ.get(\"EXO_MODELS_DIR\", None)\nEXO_MODELS_DIR = (\n    EXO_DATA_HOME / \"models\"\n    if _EXO_MODELS_DIR_ENV is None\n    else Path.home() / _EXO_MODELS_DIR_ENV\n)\n\n# Read-only search path for pre-downloaded models (colon-separated directories)\n_EXO_MODELS_PATH_ENV = os.environ.get(\"EXO_MODELS_PATH\", None)\nEXO_MODELS_PATH: tuple[Path, ...] | None = (\n    tuple(Path(p).expanduser() for p in _EXO_MODELS_PATH_ENV.split(\":\") if p)\n    if _EXO_MODELS_PATH_ENV is not None\n    else None\n)\n\n_RESOURCES_DIR_ENV = os.environ.get(\"EXO_RESOURCES_DIR\", None)\nRESOURCES_DIR = (\n    find_resources() if _RESOURCES_DIR_ENV is None else Path.home() / _RESOURCES_DIR_ENV\n)\n_DASHBOARD_DIR_ENV = os.environ.get(\"EXO_DASHBOARD_DIR\", None)\nDASHBOARD_DIR = (\n    find_dashboard() if _DASHBOARD_DIR_ENV is None else Path.home() / _DASHBOARD_DIR_ENV\n)\n\n# Log files (data/logs or cache)\nEXO_LOG_DIR = EXO_CACHE_HOME / \"exo_log\"\nEXO_LOG = EXO_LOG_DIR / \"exo.log\"\nEXO_TEST_LOG = EXO_CACHE_HOME / \"exo_test.log\"\n\n# Identity (config)\nEXO_NODE_ID_KEYPAIR = EXO_CONFIG_HOME / \"node_id.keypair\"\nEXO_CONFIG_FILE = EXO_CONFIG_HOME / \"config.toml\"\n\n# libp2p topics for event forwarding\nLIBP2P_LOCAL_EVENTS_TOPIC = \"worker_events\"\nLIBP2P_GLOBAL_EVENTS_TOPIC = \"global_events\"\nLIBP2P_ELECTION_MESSAGES_TOPIC = \"election_message\"\nLIBP2P_COMMANDS_TOPIC = \"commands\"\n\nEXO_MAX_CHUNK_SIZE = 512 * 1024\n\nEXO_CUSTOM_MODEL_CARDS_DIR = EXO_DATA_HOME / \"custom_model_cards\"\n\nEXO_EVENT_LOG_DIR = EXO_DATA_HOME / \"event_log\"\nEXO_IMAGE_CACHE_DIR = EXO_CACHE_HOME / \"images\"\nEXO_TRACING_CACHE_DIR = EXO_CACHE_HOME / \"traces\"\n\nEXO_ENABLE_IMAGE_MODELS = (\n    os.getenv(\"EXO_ENABLE_IMAGE_MODELS\", \"false\").lower() == \"true\"\n)\n\nEXO_OFFLINE = os.getenv(\"EXO_OFFLINE\", \"false\").lower() == \"true\"\n\nEXO_TRACING_ENABLED = os.getenv(\"EXO_TRACING_ENABLED\", \"false\").lower() == \"true\"\n\nEXO_MAX_CONCURRENT_REQUESTS = int(os.getenv(\"EXO_MAX_CONCURRENT_REQUESTS\", \"8\"))\n"
  },
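  {
    "path": "src/exo/shared/constants_example_sketch.py",
    "content": "\"\"\"Illustrative sketch added for exposition (not part of the original\ncodebase): restates the directory-resolution precedence implemented in\nexo.shared.constants as a standalone function, since that module reads the\nenvironment once at import time. Precedence: EXO_HOME wins outright; on\nLinux the matching XDG_* variable is honoured; otherwise fall back to ~/.exo.\"\"\"\n\nimport os\nimport sys\nfrom pathlib import Path\n\n\ndef resolve_exo_dir(xdg_env_var: str, xdg_fallback: str) -> Path:\n    exo_home = os.environ.get(\"EXO_HOME\")\n    if exo_home is not None:\n        return Path.home() / exo_home\n    if sys.platform != \"linux\":\n        return Path.home() / \".exo\"\n    xdg_value = os.environ.get(xdg_env_var)\n    if xdg_value is not None:\n        return Path(xdg_value) / \"exo\"\n    return Path.home() / xdg_fallback / \"exo\"\n\n\nif __name__ == \"__main__\":\n    # e.g. ~/.config/exo on a bare Linux box, ~/.exo elsewhere\n    print(resolve_exo_dir(\"XDG_CONFIG_HOME\", \".config\"))\n"
  },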
  {
    "path": "src/exo/shared/election.py",
    "content": "from typing import Self\n\nimport anyio\nfrom anyio import (\n    CancelScope,\n    Event,\n    get_cancelled_exc_class,\n)\nfrom loguru import logger\n\nfrom exo.routing.connection_message import ConnectionMessage\nfrom exo.shared.types.commands import ForwarderCommand\nfrom exo.shared.types.common import NodeId, SessionId\nfrom exo.utils.channels import Receiver, Sender\nfrom exo.utils.pydantic_ext import CamelCaseModel\nfrom exo.utils.task_group import TaskGroup\n\nDEFAULT_ELECTION_TIMEOUT = 3.0\n\n\nclass ElectionMessage(CamelCaseModel):\n    clock: int\n    seniority: int\n    proposed_session: SessionId\n    commands_seen: int\n\n    # Could eventually include a list of neighbour nodes for centrality\n    def __lt__(self, other: Self) -> bool:\n        if self.clock != other.clock:\n            return self.clock < other.clock\n        if self.seniority != other.seniority:\n            return self.seniority < other.seniority\n        elif self.commands_seen != other.commands_seen:\n            return self.commands_seen < other.commands_seen\n        else:\n            return (\n                self.proposed_session.master_node_id\n                < other.proposed_session.master_node_id\n            )\n\n\nclass ElectionResult(CamelCaseModel):\n    session_id: SessionId\n    won_clock: int\n    is_new_master: bool\n\n\nclass Election:\n    def __init__(\n        self,\n        node_id: NodeId,\n        *,\n        election_message_receiver: Receiver[ElectionMessage],\n        election_message_sender: Sender[ElectionMessage],\n        election_result_sender: Sender[ElectionResult],\n        connection_message_receiver: Receiver[ConnectionMessage],\n        command_receiver: Receiver[ForwarderCommand],\n        is_candidate: bool = True,\n        seniority: int = 0,\n    ):\n        # If we aren't a candidate, simply don't increment seniority.\n        # For reference: This node can be elected master if all nodes are not master candidates\n        # Any master candidate will automatically win out over this node.\n        self.seniority = seniority if is_candidate else -1\n        self.clock = 0\n        self.node_id = node_id\n        self.commands_seen = 0\n        # Every node spawns as master\n        self.current_session: SessionId = SessionId(\n            master_node_id=node_id, election_clock=0\n        )\n\n        # Senders/Receivers\n        self._em_sender = election_message_sender\n        self._em_receiver = election_message_receiver\n        self._er_sender = election_result_sender\n        self._cm_receiver = connection_message_receiver\n        self._co_receiver = command_receiver\n\n        # Campaign state\n        self._candidates: list[ElectionMessage] = []\n        self._campaign_cancel_scope: CancelScope | None = None\n        self._campaign_done: Event | None = None\n        self._tg = TaskGroup()\n\n    async def run(self):\n        logger.info(\"Starting Election\")\n        try:\n            async with self._tg as tg:\n                tg.start_soon(self._election_receiver)\n                tg.start_soon(self._connection_receiver)\n                tg.start_soon(self._command_counter)\n\n                # And start an election immediately, that instantly resolves\n                candidates: list[ElectionMessage] = []\n                logger.debug(\"Starting initial campaign\")\n                self._candidates = candidates\n                await self._campaign(candidates, campaign_timeout=0.0)\n                logger.debug(\"Initial campaign 
finished\")\n        finally:\n            # Cancel and wait for the last election to end\n            if self._campaign_cancel_scope is not None:\n                logger.debug(\"Cancelling campaign\")\n                self._campaign_cancel_scope.cancel()\n            if self._campaign_done is not None:\n                logger.debug(\"Waiting for campaign to finish\")\n                await self._campaign_done.wait()\n            logger.debug(\"Campaign cancelled and finished\")\n            logger.info(\"Election shutdown\")\n\n    async def elect(self, em: ElectionMessage) -> None:\n        logger.debug(f\"Electing: {em}\")\n        is_new_master = em.proposed_session != self.current_session\n        self.current_session = em.proposed_session\n        logger.debug(f\"Current session: {self.current_session}\")\n        await self._er_sender.send(\n            ElectionResult(\n                won_clock=em.clock,\n                session_id=em.proposed_session,\n                is_new_master=is_new_master,\n            )\n        )\n\n    async def shutdown(self) -> None:\n        self._tg.cancel_tasks()\n\n    async def _election_receiver(self) -> None:\n        with self._em_receiver as election_messages:\n            async for message in election_messages:\n                logger.debug(f\"Election message received: {message}\")\n                if message.proposed_session.master_node_id == self.node_id:\n                    logger.debug(\"Dropping message from ourselves\")\n                    # Drop messages from us (See exo.routing.router)\n                    continue\n                # If a new round is starting, we participate\n                if message.clock > self.clock:\n                    self.clock = message.clock\n                    logger.debug(f\"New clock: {self.clock}\")\n                    logger.debug(\"Starting new campaign\")\n                    candidates: list[ElectionMessage] = [message]\n                    logger.debug(f\"Candidates: {candidates}\")\n                    logger.debug(f\"Current candidates: {self._candidates}\")\n                    self._candidates = candidates\n                    logger.debug(f\"New candidates: {self._candidates}\")\n                    logger.debug(\"Starting new campaign\")\n                    self._tg.start_soon(\n                        self._campaign, candidates, DEFAULT_ELECTION_TIMEOUT\n                    )\n                    logger.debug(\"Campaign started\")\n                    continue\n                # Dismiss old messages\n                if message.clock < self.clock:\n                    logger.debug(f\"Dropping old message: {message}\")\n                    continue\n                logger.debug(f\"Election added candidate {message}\")\n                # Now we are processing this rounds messages - including the message that triggered this round.\n                self._candidates.append(message)\n\n    async def _connection_receiver(self) -> None:\n        with self._cm_receiver as connection_messages:\n            async for first in connection_messages:\n                # Delay after connection message for time to symmetrically setup\n                await anyio.sleep(0.2)\n                rest = connection_messages.collect()\n\n                logger.debug(\n                    f\"Connection messages received: {first} followed by {rest}\"\n                )\n                logger.debug(f\"Current clock: {self.clock}\")\n                # These messages are strictly peer to peer\n                
self.clock += 1\n                logger.debug(f\"New clock: {self.clock}\")\n                candidates: list[ElectionMessage] = []\n                self._candidates = candidates\n                logger.debug(\"Starting new campaign\")\n                self._tg.start_soon(\n                    self._campaign, candidates, DEFAULT_ELECTION_TIMEOUT\n                )\n                logger.debug(\"Campaign started\")\n                logger.debug(\"Connection messages handled\")\n\n    async def _command_counter(self) -> None:\n        with self._co_receiver as commands:\n            async for _command in commands:\n                self.commands_seen += 1\n\n    async def _campaign(\n        self, candidates: list[ElectionMessage], campaign_timeout: float\n    ) -> None:\n        clock = self.clock\n\n        # Kill the old campaign\n        if self._campaign_cancel_scope:\n            logger.info(\"Cancelling other campaign\")\n            self._campaign_cancel_scope.cancel()\n        if self._campaign_done:\n            logger.info(\"Waiting for other campaign to finish\")\n            await self._campaign_done.wait()\n\n        done = Event()\n        self._campaign_done = done\n        scope = CancelScope()\n        self._campaign_cancel_scope = scope\n\n        try:\n            with scope:\n                logger.debug(f\"Election {clock} started\")\n\n                status = self._election_status(clock)\n                candidates.append(status)\n                await self._em_sender.send(status)\n\n                logger.debug(f\"Sleeping for {campaign_timeout} seconds\")\n                await anyio.sleep(campaign_timeout)\n                # minor hack - rebroadcast status in case anyone has missed it.\n                await self._em_sender.send(status)\n                logger.debug(\"Woke up from sleep\")\n                # add an anyio checkpoint - anyio.lowlevel.checkpoint() or checkpoint_if_cancelled() is preferred, but wasn't typechecking last I checked\n                await anyio.sleep(0)\n\n                # Election finished!\n                elected = max(candidates)\n                logger.debug(f\"Election queue {candidates}\")\n                logger.debug(f\"Elected: {elected}\")\n                if (\n                    self.node_id == elected.proposed_session.master_node_id\n                    and self.seniority >= 0\n                ):\n                    logger.debug(\n                        f\"Node is a candidate and seniority is {self.seniority}\"\n                    )\n                    self.seniority = max(self.seniority, len(candidates))\n                    logger.debug(f\"New seniority: {self.seniority}\")\n                else:\n                    logger.debug(\n                        f\"Node did not win or is not a candidate (seniority {self.seniority})\"\n                    )\n                logger.debug(\n                    f\"Election finished, new SessionId({elected.proposed_session}) with queue {candidates}\"\n                )\n                logger.debug(\"Sending election result\")\n                await self.elect(elected)\n                logger.debug(\"Election result sent\")\n        except get_cancelled_exc_class():\n            logger.debug(f\"Election {clock} cancelled\")\n        finally:\n            logger.debug(f\"Election {clock} finally\")\n            if self._campaign_cancel_scope is scope:\n                self._campaign_cancel_scope = None\n            logger.debug(\"Setting done event\")\n            done.set()\n            
logger.debug(\"Done event set\")\n\n    def _election_status(self, clock: int | None = None) -> ElectionMessage:\n        c = self.clock if clock is None else clock\n        return ElectionMessage(\n            proposed_session=(\n                self.current_session\n                if self.current_session.master_node_id == self.node_id\n                else SessionId(master_node_id=self.node_id, election_clock=c)\n            ),\n            clock=c,\n            seniority=self.seniority,\n            commands_seen=self.commands_seen,\n        )\n"
  },
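  {
    "path": "src/exo/shared/election_example_sketch.py",
    "content": "\"\"\"Illustrative sketch added for exposition (not part of the original\ncodebase): how a campaign winner is chosen in exo.shared.election.\nCandidates are plain ElectionMessage values, and max() applies the __lt__\ntie-break chain: clock, then seniority, then commands_seen, then the\nproposing node's id.\"\"\"\n\nfrom exo.shared.election import ElectionMessage\nfrom exo.shared.types.common import NodeId, SessionId\n\n\ndef _message(node: str, seniority: int, commands_seen: int) -> ElectionMessage:\n    return ElectionMessage(\n        clock=1,\n        seniority=seniority,\n        commands_seen=commands_seen,\n        proposed_session=SessionId(master_node_id=NodeId(node), election_clock=1),\n    )\n\n\ndef pick_winner() -> ElectionMessage:\n    candidates = [\n        _message(\"A\", seniority=2, commands_seen=9),\n        _message(\"B\", seniority=5, commands_seen=0),\n        _message(\"C\", seniority=5, commands_seen=0),\n    ]\n    winner = max(candidates)\n    # B and C out-rank A on seniority; the B/C tie falls through equal\n    # commands_seen to the NodeId comparison, so \"C\" wins.\n    assert winner.proposed_session.master_node_id == NodeId(\"C\")\n    return winner\n"
  },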
  {
    "path": "src/exo/shared/logging.py",
    "content": "import logging\nimport sys\nfrom collections.abc import Iterator\nfrom pathlib import Path\n\nimport zstandard\nfrom hypercorn import Config\nfrom hypercorn.logging import Logger as HypercornLogger\nfrom loguru import logger\n\n_MAX_LOG_ARCHIVES = 5\n\n\ndef _zstd_compress(filepath: str) -> None:\n    source = Path(filepath)\n    dest = source.with_suffix(source.suffix + \".zst\")\n    cctx = zstandard.ZstdCompressor()\n    with open(source, \"rb\") as f_in, open(dest, \"wb\") as f_out:\n        cctx.copy_stream(f_in, f_out)\n    source.unlink()\n\n\ndef _once_then_never() -> Iterator[bool]:\n    yield True\n    while True:\n        yield False\n\n\nclass InterceptLogger(HypercornLogger):\n    def __init__(self, config: Config):\n        super().__init__(config)\n        assert self.error_logger\n        self.error_logger.handlers = [_InterceptHandler()]\n\n\nclass _InterceptHandler(logging.Handler):\n    def emit(self, record: logging.LogRecord):\n        try:\n            level = logger.level(record.levelname).name\n        except ValueError:\n            level = record.levelno\n\n        logger.opt(depth=3, exception=record.exc_info).log(level, record.getMessage())\n\n\ndef logger_setup(log_file: Path | None, verbosity: int = 0):\n    \"\"\"Set up logging for this process - formatting, file handles, verbosity and output\"\"\"\n\n    logging.getLogger(\"exo_pyo3_bindings\").setLevel(logging.WARNING)\n    logging.getLogger(\"httpx\").setLevel(logging.WARNING)\n    logging.getLogger(\"httpcore\").setLevel(logging.WARNING)\n\n    logger.remove()\n\n    # replace all stdlib loggers with _InterceptHandlers that log to loguru\n    logging.basicConfig(handlers=[_InterceptHandler()], level=0)\n\n    if verbosity == 0:\n        logger.add(\n            sys.__stderr__,  # type: ignore\n            format=\"[ {time:hh:mm:ss.SSSSA} | <level>{level: <8}</level>] <level>{message}</level>\",\n            level=\"INFO\",\n            colorize=True,\n            enqueue=True,\n        )\n    else:\n        logger.add(\n            sys.__stderr__,  # type: ignore\n            format=\"[ {time:HH:mm:ss.SSS} | <level>{level: <8}</level> | {name}:{function}:{line} ] <level>{message}</level>\",\n            level=\"DEBUG\",\n            colorize=True,\n            enqueue=True,\n        )\n    if log_file:\n        rotate_once = _once_then_never()\n        logger.add(\n            log_file,\n            format=\"[ {time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:{line} ] {message}\",\n            level=\"INFO\",\n            colorize=False,\n            enqueue=True,\n            rotation=lambda _, __: next(rotate_once),\n            retention=_MAX_LOG_ARCHIVES,\n            compression=_zstd_compress,\n        )\n\n\ndef logger_cleanup():\n    \"\"\"Flush all queues before shutting down so any in-flight logs are written to disk\"\"\"\n    logger.complete()\n\n\n\"\"\" --- TODO: Capture MLX Log output:\nimport contextlib\nimport sys\nfrom loguru import logger\n\nclass StreamToLogger:\n\n    def __init__(self, level=\"INFO\"):\n        self._level = level\n\n    def write(self, buffer):\n        for line in buffer.rstrip().splitlines():\n            logger.opt(depth=1).log(self._level, line.rstrip())\n\n    def flush(self):\n        pass\n\nlogger.remove()\nlogger.add(sys.__stdout__)\n\nstream = StreamToLogger()\nwith contextlib.redirect_stdout(stream):\n    print(\"Standard output is sent to added handlers.\")\n\"\"\"\n"
  },
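  {
    "path": "src/exo/shared/logging_example_sketch.py",
    "content": "\"\"\"Illustrative sketch added for exposition (not part of the original\ncodebase): the intended wiring of exo.shared.logging at process start-up.\nThe log-file path below is an arbitrary example.\"\"\"\n\nfrom pathlib import Path\n\nfrom loguru import logger\n\nfrom exo.shared.logging import logger_cleanup, logger_setup\n\nif __name__ == \"__main__\":\n    # verbosity=0 keeps the terse INFO console format; >=1 switches to the\n    # DEBUG format that includes module, function and line number.\n    logger_setup(log_file=Path(\"/tmp/exo-example/exo.log\"), verbosity=1)\n    logger.info(\"sinks configured\")\n    # All sinks use enqueue=True, so drain the queues before exiting.\n    logger_cleanup()\n"
  },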
  {
    "path": "src/exo/shared/models/model_cards.py",
    "content": "from enum import Enum\nfrom typing import Annotated, Any\n\nimport aiofiles\nimport aiofiles.os as aios\nimport tomlkit\nfrom anyio import Path, open_file\nfrom huggingface_hub import model_info\nfrom loguru import logger\nfrom pydantic import (\n    AliasChoices,\n    BaseModel,\n    Field,\n    PositiveInt,\n    ValidationError,\n    field_validator,\n    model_validator,\n)\nfrom tomlkit.exceptions import TOMLKitError\n\nfrom exo.shared.constants import (\n    EXO_CUSTOM_MODEL_CARDS_DIR,\n    EXO_ENABLE_IMAGE_MODELS,\n    RESOURCES_DIR,\n)\nfrom exo.shared.types.common import ModelId\nfrom exo.shared.types.memory import Memory\nfrom exo.utils.pydantic_ext import CamelCaseModel\n\n# kinda ugly...\n# TODO: load search path from config.toml\n_custom_cards_dir = Path(str(EXO_CUSTOM_MODEL_CARDS_DIR))\nCARD_SEARCH_PATH = [\n    Path(RESOURCES_DIR) / \"inference_model_cards\",\n    Path(RESOURCES_DIR) / \"image_model_cards\",\n    _custom_cards_dir,\n]\n\n_card_cache: dict[ModelId, \"ModelCard\"] = {}\n\n\nasync def _refresh_card_cache():\n    for path in CARD_SEARCH_PATH:\n        async for toml_file in path.rglob(\"*.toml\"):\n            try:\n                card = await ModelCard.load_from_path(toml_file)\n                if card.model_id not in _card_cache:\n                    _card_cache[card.model_id] = card\n            except (ValidationError, TOMLKitError):\n                pass\n\n\ndef _is_image_card(card: \"ModelCard\") -> bool:\n    return any(t in (ModelTask.TextToImage, ModelTask.ImageToImage) for t in card.tasks)\n\n\nasync def get_model_cards() -> list[\"ModelCard\"]:\n    if len(_card_cache) == 0:\n        await _refresh_card_cache()\n    if EXO_ENABLE_IMAGE_MODELS:\n        return list(_card_cache.values())\n    return [c for c in _card_cache.values() if not _is_image_card(c)]\n\n\nclass ModelTask(str, Enum):\n    TextGeneration = \"TextGeneration\"\n    TextToImage = \"TextToImage\"\n    ImageToImage = \"ImageToImage\"\n\n\nclass ComponentInfo(CamelCaseModel):\n    component_name: str\n    component_path: str\n    storage_size: Memory\n    n_layers: PositiveInt | None = None\n    can_shard: bool\n    safetensors_index_filename: str | None = None\n\n\nclass ModelCard(CamelCaseModel):\n    model_id: ModelId\n    storage_size: Memory\n    n_layers: PositiveInt\n    hidden_size: PositiveInt\n    supports_tensor: bool\n    num_key_value_heads: PositiveInt | None = None\n    tasks: list[ModelTask]\n    components: list[ComponentInfo] | None = None\n    family: str = \"\"\n    quantization: str = \"\"\n    base_model: str = \"\"\n    capabilities: list[str] = []\n    uses_cfg: bool = False\n    trust_remote_code: bool = True\n\n    @field_validator(\"tasks\", mode=\"before\")\n    @classmethod\n    def _validate_tasks(cls, v: list[str | ModelTask]) -> list[ModelTask]:\n        return [item if isinstance(item, ModelTask) else ModelTask(item) for item in v]\n\n    async def save(self, path: Path) -> None:\n        async with await open_file(path, \"w\") as f:\n            py = self.model_dump(exclude_none=True)\n            data = tomlkit.dumps(py)  # pyright: ignore[reportUnknownMemberType]\n            await f.write(data)\n\n    async def save_to_custom_dir(self) -> None:\n        await aios.makedirs(str(_custom_cards_dir), exist_ok=True)\n        await self.save(_custom_cards_dir / (self.model_id.normalize() + \".toml\"))\n\n    @staticmethod\n    async def load_from_path(path: Path) -> \"ModelCard\":\n        async with await open_file(path, \"r\") as f:\n     
       py = tomlkit.loads(await f.read())\n            return ModelCard.model_validate(py)\n\n    # Is it okay that ModelCard.load defaults to network access if the card doesn't exist? Do we want to be more explicit here?\n    @staticmethod\n    async def load(model_id: ModelId) -> \"ModelCard\":\n        if model_id not in _card_cache:\n            await _refresh_card_cache()\n        if (mc := _card_cache.get(model_id)) is not None:\n            return mc\n\n        return await ModelCard.fetch_from_hf(model_id)\n\n    @staticmethod\n    async def fetch_from_hf(model_id: ModelId) -> \"ModelCard\":\n        \"\"\"Fetches storage size and number of layers for a Hugging Face model, returning a validated ModelCard.\"\"\"\n        # TODO: handle failure if the files do not exist\n        config_data = await fetch_config_data(model_id)\n        num_layers = config_data.layer_count\n        mem_size_bytes = await fetch_safetensors_size(model_id)\n\n        mc = ModelCard(\n            model_id=ModelId(model_id),\n            storage_size=mem_size_bytes,\n            n_layers=num_layers,\n            hidden_size=config_data.hidden_size or 0,\n            supports_tensor=config_data.supports_tensor,\n            num_key_value_heads=config_data.num_key_value_heads,\n            tasks=[ModelTask.TextGeneration],\n            trust_remote_code=False,\n        )\n        await mc.save_to_custom_dir()\n        _card_cache[model_id] = mc\n        return mc\n\n\nasync def delete_custom_card(model_id: ModelId) -> bool:\n    \"\"\"Delete a user-added custom model card. Returns True if deleted.\"\"\"\n    card_path = _custom_cards_dir / (ModelId(model_id).normalize() + \".toml\")\n    if await card_path.exists():\n        await card_path.unlink()\n        _card_cache.pop(model_id, None)\n        return True\n    return False\n\n\ndef is_custom_card(model_id: ModelId) -> bool:\n    \"\"\"Check if a model card exists in the custom cards directory.\"\"\"\n    import os\n\n    card_path = Path(str(EXO_CUSTOM_MODEL_CARDS_DIR)) / (\n        ModelId(model_id).normalize() + \".toml\"\n    )\n    return os.path.isfile(str(card_path))\n\n\nclass ConfigData(BaseModel):\n    model_config = {\"extra\": \"ignore\"}  # Allow unknown fields\n\n    architectures: list[str] | None = None\n    hidden_size: Annotated[int, Field(ge=0)] | None = None\n    num_key_value_heads: PositiveInt | None = None\n    layer_count: int = Field(\n        validation_alias=AliasChoices(\n            \"num_hidden_layers\",\n            \"num_layers\",\n            \"n_layer\",\n            \"n_layers\",\n            \"num_decoder_layers\",\n            \"decoder_layers\",\n        )\n    )\n\n    @property\n    def supports_tensor(self) -> bool:\n        return self.architectures in [\n            [\"Glm4MoeLiteForCausalLM\"],\n            [\"GlmMoeDsaForCausalLM\"],\n            [\"DeepseekV32ForCausalLM\"],\n            [\"DeepseekV3ForCausalLM\"],\n            [\"Qwen3NextForCausalLM\"],\n            [\"Qwen3MoeForCausalLM\"],\n            [\"Qwen3_5MoeForConditionalGeneration\"],\n            [\"Qwen3_5ForConditionalGeneration\"],\n            [\"MiniMaxM2ForCausalLM\"],\n            [\"LlamaForCausalLM\"],\n            [\"GptOssForCausalLM\"],\n            [\"Step3p5ForCausalLM\"],\n            [\"NemotronHForCausalLM\"],\n        ]\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def defer_to_text_config(cls, data: dict[str, Any]):\n        text_config = data.get(\"text_config\")\n        if text_config is None:\n            return 
data\n\n        for field in [\n            \"architectures\",\n            \"hidden_size\",\n            \"num_key_value_heads\",\n            \"num_hidden_layers\",\n            \"num_layers\",\n            \"n_layer\",\n            \"n_layers\",\n            \"num_decoder_layers\",\n            \"decoder_layers\",\n        ]:\n            if (val := text_config.get(field)) is not None:  # pyright: ignore[reportAny]\n                data[field] = val\n\n        return data\n\n\nasync def fetch_config_data(model_id: ModelId) -> ConfigData:\n    \"\"\"Downloads and parses config.json for a model.\"\"\"\n    from exo.download.download_utils import (\n        download_file_with_retry,\n        ensure_models_dir,\n    )\n\n    target_dir = (await ensure_models_dir()) / model_id.normalize()\n    await aios.makedirs(target_dir, exist_ok=True)\n    config_path = await download_file_with_retry(\n        model_id,\n        \"main\",\n        \"config.json\",\n        target_dir,\n        lambda curr_bytes, total_bytes, is_renamed: logger.debug(\n            f\"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})\"\n        ),\n    )\n    async with aiofiles.open(config_path, \"r\") as f:\n        return ConfigData.model_validate_json(await f.read())\n\n\nasync def fetch_safetensors_size(model_id: ModelId) -> Memory:\n    \"\"\"Gets model size from safetensors index or falls back to HF API.\"\"\"\n    from exo.download.download_utils import (\n        download_file_with_retry,\n        ensure_models_dir,\n    )\n    from exo.shared.types.worker.downloads import ModelSafetensorsIndex\n\n    target_dir = (await ensure_models_dir()) / model_id.normalize()\n    await aios.makedirs(target_dir, exist_ok=True)\n    index_path = await download_file_with_retry(\n        model_id,\n        \"main\",\n        \"model.safetensors.index.json\",\n        target_dir,\n        lambda curr_bytes, total_bytes, is_renamed: logger.debug(\n            f\"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})\"\n        ),\n    )\n    async with aiofiles.open(index_path, \"r\") as f:\n        index_data = ModelSafetensorsIndex.model_validate_json(await f.read())\n\n    metadata = index_data.metadata\n    if metadata is not None:\n        return Memory.from_bytes(metadata.total_size)\n\n    info = model_info(model_id)\n    if info.safetensors is None:\n        raise ValueError(f\"No safetensors info found for {model_id}\")\n    return Memory.from_bytes(info.safetensors.total)\n"
  },
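  {
    "path": "src/exo/shared/models/model_cards_example_sketch.py",
    "content": "\"\"\"Illustrative sketch added for exposition (not part of the original\ncodebase): the ModelCard lookup flow in exo.shared.models.model_cards.\nload() consults the TOML card cache first and only falls back to fetching\nmetadata from Hugging Face on a miss.\"\"\"\n\nimport anyio\n\nfrom exo.shared.models.model_cards import ModelCard, get_model_cards\n\n\nasync def show_cards() -> None:\n    # Cards discovered on CARD_SEARCH_PATH (image cards are filtered out\n    # unless EXO_ENABLE_IMAGE_MODELS is set).\n    cards = await get_model_cards()\n    for card in cards:\n        print(card.model_id, card.n_layers, card.storage_size)\n\n    if cards:\n        # Hits the in-memory cache; a miss would fetch from the Hugging Face\n        # API and persist the resulting card to the custom cards directory.\n        reloaded = await ModelCard.load(cards[0].model_id)\n        print(reloaded.supports_tensor)\n\n\nif __name__ == \"__main__\":\n    anyio.run(show_cards)\n"
  },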
  {
    "path": "src/exo/shared/tests/__init__.py",
    "content": "# Test package for shared utilities\n"
  },
  {
    "path": "src/exo/shared/tests/conftest.py",
    "content": "\"\"\"Pytest configuration and shared fixtures for shared package tests.\"\"\"\n\nimport asyncio\nfrom typing import Generator\n\nimport pytest\nfrom _pytest.logging import LogCaptureFixture\nfrom loguru import logger\n\nfrom exo.shared.models.model_cards import ModelCard, ModelId, ModelTask\nfrom exo.shared.types.memory import Memory\nfrom exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata\n\n\n@pytest.fixture(scope=\"session\")\ndef event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:\n    \"\"\"Create an event loop for the test session.\"\"\"\n    loop = asyncio.new_event_loop()\n    asyncio.set_event_loop(loop)\n    yield loop\n    loop.close()\n\n\n@pytest.fixture(autouse=True)\ndef reset_event_loop():\n    \"\"\"Reset the event loop for each test to ensure clean state.\"\"\"\n    # This ensures each test gets a fresh event loop state\n\n\ndef get_pipeline_shard_metadata(\n    model_id: ModelId, device_rank: int, world_size: int = 1\n) -> ShardMetadata:\n    return PipelineShardMetadata(\n        model_card=ModelCard(\n            model_id=model_id,\n            storage_size=Memory.from_mb(100000),\n            n_layers=32,\n            hidden_size=1000,\n            supports_tensor=True,\n            tasks=[ModelTask.TextGeneration],\n        ),\n        device_rank=device_rank,\n        world_size=world_size,\n        start_layer=0,\n        end_layer=32,\n        n_layers=32,\n    )\n\n\n@pytest.fixture\ndef caplog(caplog: LogCaptureFixture):\n    handler_id = logger.add(\n        caplog.handler,\n        format=\"{message}\",\n        level=0,\n        filter=lambda record: record[\"level\"].no >= caplog.handler.level,\n        enqueue=True,  # Set to 'True' if your test is spawning child processes.\n    )\n    yield caplog\n    logger.remove(handler_id)\n"
  },
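  {
    "path": "src/exo/shared/tests/test_caplog_fixture_sketch.py",
    "content": "\"\"\"Illustrative sketch added for exposition (not part of the original\ncodebase): using the loguru-aware caplog fixture from conftest.py. Loguru\nrecords propagate into pytest's capture handler, so the usual caplog\nassertions work unchanged.\"\"\"\n\nfrom _pytest.logging import LogCaptureFixture\nfrom loguru import logger\n\n\ndef test_loguru_messages_are_captured(caplog: LogCaptureFixture) -> None:\n    caplog.set_level(\"INFO\")\n    logger.info(\"hello from loguru\")\n    # The fixture adds its sink with enqueue=True, so drain the queue before\n    # asserting on the captured text.\n    logger.complete()\n    assert \"hello from loguru\" in caplog.text\n"
  },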
  {
    "path": "src/exo/shared/tests/test_apply/test_apply_node_download.py",
    "content": "from exo.shared.apply import apply_node_download_progress\nfrom exo.shared.tests.conftest import get_pipeline_shard_metadata\nfrom exo.shared.types.common import NodeId\nfrom exo.shared.types.events import NodeDownloadProgress\nfrom exo.shared.types.memory import Memory\nfrom exo.shared.types.state import State\nfrom exo.shared.types.worker.downloads import DownloadCompleted\nfrom exo.worker.tests.constants import MODEL_A_ID, MODEL_B_ID\n\n\ndef test_apply_node_download_progress():\n    state = State()\n    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)\n    event = DownloadCompleted(\n        node_id=NodeId(\"node-1\"),\n        shard_metadata=shard1,\n        total=Memory(),\n    )\n\n    new_state = apply_node_download_progress(\n        NodeDownloadProgress(download_progress=event), state\n    )\n\n    assert new_state.downloads == {NodeId(\"node-1\"): [event]}\n\n\ndef test_apply_two_node_download_progress():\n    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)\n    shard2 = get_pipeline_shard_metadata(MODEL_B_ID, device_rank=0, world_size=2)\n    event1 = DownloadCompleted(\n        node_id=NodeId(\"node-1\"),\n        shard_metadata=shard1,\n        total=Memory(),\n    )\n    event2 = DownloadCompleted(\n        node_id=NodeId(\"node-1\"),\n        shard_metadata=shard2,\n        total=Memory(),\n    )\n    state = State(downloads={NodeId(\"node-1\"): [event1]})\n\n    new_state = apply_node_download_progress(\n        NodeDownloadProgress(download_progress=event2), state\n    )\n\n    assert new_state.downloads == {NodeId(\"node-1\"): [event1, event2]}\n"
  },
  {
    "path": "src/exo/shared/tests/test_apply/test_apply_runner_deleted.py",
    "content": "from exo.shared.apply import apply_runner_status_updated\nfrom exo.shared.types.events import RunnerStatusUpdated\nfrom exo.shared.types.state import State\nfrom exo.shared.types.worker.runners import RunnerId, RunnerIdle, RunnerShutdown\n\n\ndef test_apply_runner_shutdown_removes_runner():\n    runner_id = RunnerId()\n    state = State(runners={runner_id: RunnerIdle()})\n\n    new_state = apply_runner_status_updated(\n        RunnerStatusUpdated(runner_id=runner_id, runner_status=RunnerShutdown()), state\n    )\n\n    assert runner_id not in new_state.runners\n\n\ndef test_apply_runner_status_updated_adds_runner():\n    runner_id = RunnerId()\n    state = State()\n\n    new_state = apply_runner_status_updated(\n        RunnerStatusUpdated(runner_id=runner_id, runner_status=RunnerIdle()), state\n    )\n\n    assert runner_id in new_state.runners\n"
  },
  {
    "path": "src/exo/shared/tests/test_election.py",
    "content": "import pytest\nfrom anyio import create_task_group, fail_after, move_on_after\n\nfrom exo.routing.connection_message import ConnectionMessage\nfrom exo.shared.election import Election, ElectionMessage, ElectionResult\nfrom exo.shared.types.commands import ForwarderCommand, TestCommand\nfrom exo.shared.types.common import NodeId, SessionId, SystemId\nfrom exo.utils.channels import channel\n\n# ======= #\n# Helpers #\n# ======= #\n\n\ndef em(\n    clock: int,\n    seniority: int,\n    node_id: str,\n    commands_seen: int = 0,\n    election_clock: int | None = None,\n) -> ElectionMessage:\n    \"\"\"\n    Helper to build ElectionMessages for a given proposer node.\n\n    The new API carries a proposed SessionId (master_node_id + election_clock).\n    By default we use the same value for election_clock as the 'clock' of the round.\n    \"\"\"\n    return ElectionMessage(\n        clock=clock,\n        seniority=seniority,\n        proposed_session=SessionId(\n            master_node_id=NodeId(node_id),\n            election_clock=clock if election_clock is None else election_clock,\n        ),\n        commands_seen=commands_seen,\n    )\n\n\n# ======================================= #\n#                 TESTS                   #\n# ======================================= #\n\n\n@pytest.fixture(autouse=True)\ndef fast_election_timeout(monkeypatch: pytest.MonkeyPatch):\n    monkeypatch.setattr(\"exo.shared.election.DEFAULT_ELECTION_TIMEOUT\", 0.1)\n\n\n@pytest.mark.anyio\nasync def test_single_round_broadcasts_and_updates_seniority_on_self_win() -> None:\n    \"\"\"\n    Start a round by injecting an ElectionMessage with higher clock.\n    With only our node effectively 'winning', we should broadcast once and update seniority.\n    \"\"\"\n    # Outbound election messages from the Election (we'll observe these)\n    em_out_tx, em_out_rx = channel[ElectionMessage]()\n    # Inbound election messages to the Election (we'll inject these)\n    em_in_tx, em_in_rx = channel[ElectionMessage]()\n    # Election results produced by the Election (we'll observe these)\n    er_tx, er_rx = channel[ElectionResult]()\n    # Connection messages\n    cm_tx, cm_rx = channel[ConnectionMessage]()\n    # Commands\n    co_tx, co_rx = channel[ForwarderCommand]()\n\n    election = Election(\n        node_id=NodeId(\"B\"),\n        election_message_receiver=em_in_rx,\n        election_message_sender=em_out_tx,\n        election_result_sender=er_tx,\n        connection_message_receiver=cm_rx,\n        command_receiver=co_rx,\n        is_candidate=True,\n    )\n\n    async with create_task_group() as tg:\n        with fail_after(2):\n            tg.start_soon(election.run)\n            # Trigger new round at clock=1 (peer announces it)\n            await em_in_tx.send(em(clock=1, seniority=0, node_id=\"A\"))\n\n            # Expect our broadcast back to the peer side for this round only\n            while True:\n                got = await em_out_rx.receive()\n                if got.clock == 1 and got.proposed_session.master_node_id == NodeId(\n                    \"B\"\n                ):\n                    break\n\n            # Wait for the round to finish and produce an ElectionResult\n            result = await er_rx.receive()\n            assert result.session_id.master_node_id == NodeId(\"B\")\n            # We spawned as master; electing ourselves again is not \"new master\".\n            assert result.is_new_master is False\n\n            # Close inbound streams to end the receivers (and 
run())\n            em_in_tx.close()\n            cm_tx.close()\n            co_tx.close()\n\n    # We should have updated seniority to 2 (A + B).\n    assert election.seniority == 2\n\n\n@pytest.mark.anyio\nasync def test_peer_with_higher_seniority_wins_and_we_switch_master() -> None:\n    \"\"\"\n    If a peer with clearly higher seniority participates in the round, they should win.\n    We should broadcast our status exactly once for this round, then switch master.\n    \"\"\"\n    em_out_tx, em_out_rx = channel[ElectionMessage]()\n    em_in_tx, em_in_rx = channel[ElectionMessage]()\n    er_tx, er_rx = channel[ElectionResult]()\n    cm_tx, cm_rx = channel[ConnectionMessage]()\n    co_tx, co_rx = channel[ForwarderCommand]()\n\n    election = Election(\n        node_id=NodeId(\"ME\"),\n        election_message_receiver=em_in_rx,\n        election_message_sender=em_out_tx,\n        election_result_sender=er_tx,\n        connection_message_receiver=cm_rx,\n        command_receiver=co_rx,\n        is_candidate=True,\n    )\n\n    async with create_task_group() as tg:\n        with fail_after(2):\n            tg.start_soon(election.run)\n\n            # Start round with peer's message (higher seniority)\n            await em_in_tx.send(em(clock=1, seniority=10, node_id=\"PEER\"))\n\n            # We should still broadcast our status exactly once for this round\n            while True:\n                got = await em_out_rx.receive()\n                if got.clock == 1:\n                    assert got.seniority == 0\n                    break\n\n            # After the timeout, election result for clock=1 should report the peer as master\n            # (Skip any earlier result from the boot campaign at clock=0 by filtering on election_clock)\n            while True:\n                result = await er_rx.receive()\n                if result.session_id.election_clock == 1:\n                    break\n\n            assert result.session_id.master_node_id == NodeId(\"PEER\")\n            assert result.is_new_master is True\n\n            em_in_tx.close()\n            cm_tx.close()\n            co_tx.close()\n\n    # We lost → seniority unchanged\n    assert election.seniority == 0\n\n\n@pytest.mark.anyio\nasync def test_ignores_older_messages() -> None:\n    \"\"\"\n    Messages with a lower clock than the current round are ignored by the receiver.\n    Expect exactly one broadcast for the higher clock round.\n    \"\"\"\n    em_out_tx, em_out_rx = channel[ElectionMessage]()\n    em_in_tx, em_in_rx = channel[ElectionMessage]()\n    er_tx, _er_rx = channel[ElectionResult]()\n    cm_tx, cm_rx = channel[ConnectionMessage]()\n    co_tx, co_rx = channel[ForwarderCommand]()\n\n    election = Election(\n        node_id=NodeId(\"ME\"),\n        election_message_receiver=em_in_rx,\n        election_message_sender=em_out_tx,\n        election_result_sender=er_tx,\n        connection_message_receiver=cm_rx,\n        command_receiver=co_rx,\n        is_candidate=True,\n    )\n\n    async with create_task_group() as tg:\n        with fail_after(2):\n            tg.start_soon(election.run)\n\n            # Newer round arrives first -> triggers campaign at clock=2\n            await em_in_tx.send(em(clock=2, seniority=0, node_id=\"A\"))\n            while True:\n                first = await em_out_rx.receive()\n                if first.clock == 2:\n                    break\n\n            # Older message (clock=1) must be ignored (no second broadcast)\n            await em_in_tx.send(em(clock=1, seniority=999, 
node_id=\"B\"))\n\n            got_second = False\n            with move_on_after(0.05):\n                _ = await em_out_rx.receive()\n                got_second = True\n            assert not got_second, \"Should not receive a broadcast for an older round\"\n\n            em_in_tx.close()\n            cm_tx.close()\n            co_tx.close()\n\n    # Not asserting on the result; focus is on ignore behavior.\n\n\n@pytest.mark.anyio\nasync def test_two_rounds_emit_two_broadcasts_and_increment_clock() -> None:\n    \"\"\"\n    Two successive rounds → two broadcasts. Second round triggered by a higher-clock message.\n    \"\"\"\n    em_out_tx, em_out_rx = channel[ElectionMessage]()\n    em_in_tx, em_in_rx = channel[ElectionMessage]()\n    er_tx, _er_rx = channel[ElectionResult]()\n    cm_tx, cm_rx = channel[ConnectionMessage]()\n    co_tx, co_rx = channel[ForwarderCommand]()\n\n    election = Election(\n        node_id=NodeId(\"ME\"),\n        election_message_receiver=em_in_rx,\n        election_message_sender=em_out_tx,\n        election_result_sender=er_tx,\n        connection_message_receiver=cm_rx,\n        command_receiver=co_rx,\n        is_candidate=True,\n    )\n\n    async with create_task_group() as tg:\n        with fail_after(2):\n            tg.start_soon(election.run)\n\n            # Round 1 at clock=1\n            await em_in_tx.send(em(clock=1, seniority=0, node_id=\"X\"))\n            while True:\n                m1 = await em_out_rx.receive()\n                if m1.clock == 1:\n                    break\n\n            # Round 2 at clock=2\n            await em_in_tx.send(em(clock=2, seniority=0, node_id=\"Y\"))\n            while True:\n                m2 = await em_out_rx.receive()\n                if m2.clock == 2:\n                    break\n\n            em_in_tx.close()\n            cm_tx.close()\n            co_tx.close()\n\n    # Not asserting on who won; just that both rounds were broadcast.\n\n\n@pytest.mark.anyio\nasync def test_promotion_new_seniority_counts_participants() -> None:\n    \"\"\"\n    When we win against two peers in the same round, our seniority becomes\n    max(existing, number_of_candidates). 
With existing=0: expect 3 (us + A + B).\n    \"\"\"\n    em_out_tx, em_out_rx = channel[ElectionMessage]()\n    em_in_tx, em_in_rx = channel[ElectionMessage]()\n    er_tx, er_rx = channel[ElectionResult]()\n    cm_tx, cm_rx = channel[ConnectionMessage]()\n    co_tx, co_rx = channel[ForwarderCommand]()\n\n    election = Election(\n        node_id=NodeId(\"ME\"),\n        election_message_receiver=em_in_rx,\n        election_message_sender=em_out_tx,\n        election_result_sender=er_tx,\n        connection_message_receiver=cm_rx,\n        command_receiver=co_rx,\n        is_candidate=True,\n    )\n\n    async with create_task_group() as tg:\n        with fail_after(2):\n            tg.start_soon(election.run)\n\n            # Start round at clock=7 with two peer participants\n            await em_in_tx.send(em(clock=7, seniority=0, node_id=\"A\"))\n            await em_in_tx.send(em(clock=7, seniority=0, node_id=\"B\"))\n\n            # We should see exactly one broadcast from us for this round\n            while True:\n                got = await em_out_rx.receive()\n                if got.clock == 7 and got.proposed_session.master_node_id == NodeId(\n                    \"ME\"\n                ):\n                    break\n\n            # Wait for the election to finish so seniority updates\n            _ = await er_rx.receive()\n\n            em_in_tx.close()\n            cm_tx.close()\n            co_tx.close()\n\n    # We + A + B = 3 → new seniority expected to be 3\n    assert election.seniority == 3\n\n\n@pytest.mark.anyio\nasync def test_connection_message_triggers_new_round_broadcast() -> None:\n    \"\"\"\n    A connection message increments the clock and starts a new campaign.\n    We should observe a broadcast at the incremented clock.\n    \"\"\"\n    em_out_tx, em_out_rx = channel[ElectionMessage]()\n    em_in_tx, em_in_rx = channel[ElectionMessage]()\n    er_tx, _er_rx = channel[ElectionResult]()\n    cm_tx, cm_rx = channel[ConnectionMessage]()\n    co_tx, co_rx = channel[ForwarderCommand]()\n\n    election = Election(\n        node_id=NodeId(\"ME\"),\n        election_message_receiver=em_in_rx,\n        election_message_sender=em_out_tx,\n        election_result_sender=er_tx,\n        connection_message_receiver=cm_rx,\n        command_receiver=co_rx,\n        is_candidate=True,\n    )\n\n    async with create_task_group() as tg:\n        with fail_after(2):\n            tg.start_soon(election.run)\n\n            # Send any connection message object; we close quickly to cancel before result creation\n            await cm_tx.send(ConnectionMessage(node_id=NodeId(), connected=True))\n\n            # Expect a broadcast for the new round at clock=1\n            while True:\n                got = await em_out_rx.receive()\n                if got.clock == 1 and got.proposed_session.master_node_id == NodeId(\n                    \"ME\"\n                ):\n                    break\n\n            # Close promptly to avoid waiting for campaign completion\n            em_in_tx.close()\n            cm_tx.close()\n            co_tx.close()\n\n    # After cancellation (before election finishes), no seniority changes asserted here.\n\n\n@pytest.mark.anyio\nasync def test_tie_breaker_prefers_node_with_more_commands_seen() -> None:\n    \"\"\"\n    With equal seniority, the node that has seen more commands should win the election.\n    We increase our local 'commands_seen' by sending TestCommand()s before triggering the round.\n    \"\"\"\n    em_out_tx, em_out_rx = 
channel[ElectionMessage]()\n    em_in_tx, em_in_rx = channel[ElectionMessage]()\n    er_tx, er_rx = channel[ElectionResult]()\n    cm_tx, cm_rx = channel[ConnectionMessage]()\n    co_tx, co_rx = channel[ForwarderCommand]()\n\n    me = NodeId(\"ME\")\n\n    election = Election(\n        node_id=me,\n        election_message_receiver=em_in_rx,\n        election_message_sender=em_out_tx,\n        election_result_sender=er_tx,\n        connection_message_receiver=cm_rx,\n        command_receiver=co_rx,\n        is_candidate=True,\n        seniority=0,\n    )\n\n    async with create_task_group() as tg:\n        with fail_after(2):\n            tg.start_soon(election.run)\n\n            # Pump local commands so our commands_seen is high before the round starts\n            for _ in range(50):\n                await co_tx.send(\n                    ForwarderCommand(origin=SystemId(\"SOMEONE\"), command=TestCommand())\n                )\n\n            # Trigger a round at clock=1 with a peer of equal seniority but fewer commands\n            await em_in_tx.send(\n                em(clock=1, seniority=0, node_id=\"PEER\", commands_seen=5)\n            )\n\n            # Observe our broadcast for this round (to ensure we've joined the round)\n            while True:\n                got = await em_out_rx.receive()\n                if got.clock == 1 and got.proposed_session.master_node_id == me:\n                    # We don't assert exact count, just that we've participated this round.\n                    break\n\n            # The elected result for clock=1 should be us due to higher commands_seen\n            while True:\n                result = await er_rx.receive()\n                if result.session_id.master_node_id == me:\n                    assert result.session_id.election_clock in (0, 1)\n                    break\n\n            em_in_tx.close()\n            cm_tx.close()\n            co_tx.close()\n"
  },
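The tie-breaker test above encodes an ordering that is easy to miss in the channel plumbing: seniority is compared first, and `commands_seen` only breaks ties. A purely illustrative sketch of that relation follows; this is a hypothetical helper, not the `Election` internals, just the ordering the assertions above rely on.

```python
# Hypothetical helper, not part of exo: each tuple is (seniority, commands_seen).
# Lexicographic comparison gives seniority priority, with commands_seen as the
# tie-breaker, which is what the tests above assert.
def wins_election(mine: tuple[int, int], theirs: tuple[int, int]) -> bool:
    return mine > theirs

assert wins_election((0, 50), (0, 5))   # equal seniority: more commands seen wins
assert wins_election((2, 0), (1, 99))   # higher seniority dominates
```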
  {
    "path": "src/exo/shared/tests/test_node_id_persistence.py",
    "content": "import contextlib\nimport multiprocessing\nimport os\nfrom multiprocessing import Event, Queue, Semaphore\nfrom multiprocessing.process import BaseProcess\nfrom multiprocessing.queues import Queue as QueueT\nfrom multiprocessing.synchronize import Event as EventT\nfrom multiprocessing.synchronize import Semaphore as SemaphoreT\n\nfrom loguru import logger\nfrom pytest import LogCaptureFixture, mark\n\nfrom exo.routing.router import get_node_id_keypair\nfrom exo.shared.constants import EXO_NODE_ID_KEYPAIR\n\nNUM_CONCURRENT_PROCS = 10\n\n\ndef _get_keypair_concurrent_subprocess_task(\n    sem: SemaphoreT, ev: EventT, queue: QueueT[bytes]\n) -> None:\n    # synchronise with parent process\n    sem.release()\n    # wait to be told to begin simultaneous read\n    ev.wait()\n    queue.put(get_node_id_keypair().to_bytes())\n\n\ndef _get_keypair_concurrent(num_procs: int) -> bytes:\n    assert num_procs > 0\n\n    sem = Semaphore(0)\n    ev = Event()\n    queue: QueueT[bytes] = Queue(maxsize=num_procs)\n\n    # make parent process wait for all subprocesses to start\n    logger.info(f\"PARENT: Starting {num_procs} subprocesses\")\n    ps: list[BaseProcess] = []\n    for _ in range(num_procs):\n        p = multiprocessing.get_context(\"fork\").Process(\n            target=_get_keypair_concurrent_subprocess_task, args=(sem, ev, queue)\n        )\n        ps.append(p)\n        p.start()\n    for _ in range(num_procs):\n        sem.acquire()\n\n    # start all the sub processes simultaneously\n    logger.info(\"PARENT: Beginning read\")\n    ev.set()\n\n    # wait until all subprocesses are done & read results\n    for p in ps:\n        p.join()\n\n    # check that the input/output order match, and that\n    # all subprocesses end up reading the same file\n    logger.info(\"PARENT: Checking consistency\")\n    keypair: bytes | None = None\n    qsize = 0  # cannot use Queue.qsize due to MacOS incompatibility :(\n    while not queue.empty():\n        qsize += 1\n        temp_keypair = queue.get()\n        if keypair is None:\n            keypair = temp_keypair\n        else:\n            assert keypair == temp_keypair\n    assert num_procs == qsize\n    return keypair  # pyright: ignore[reportReturnType]\n\n\ndef _delete_if_exists(p: str | bytes | os.PathLike[str] | os.PathLike[bytes]):\n    with contextlib.suppress(OSError):\n        os.remove(p)\n\n\n@mark.skip(reason=\"this functionality is currently disabled but may return in future\")\ndef test_node_id_fetching(caplog: LogCaptureFixture):\n    reps = 10\n\n    # delete current file and write a new one\n    _delete_if_exists(EXO_NODE_ID_KEYPAIR)\n    kp = _get_keypair_concurrent(NUM_CONCURRENT_PROCS)\n\n    with caplog.at_level(101):  # supress logs\n        # make sure that continuous fetches return the same value\n        for _ in range(reps):\n            assert kp == _get_keypair_concurrent(NUM_CONCURRENT_PROCS)\n\n        # make sure that after deleting, we are not fetching the same value\n        _delete_if_exists(EXO_NODE_ID_KEYPAIR)\n        for _ in range(reps):\n            assert kp != _get_keypair_concurrent(NUM_CONCURRENT_PROCS)\n"
  },
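`_get_keypair_concurrent` uses a semaphore-plus-event start barrier so every subprocess hits the keypair file at the same moment. A minimal runnable sketch of that pattern in isolation:

```python
import multiprocessing
from multiprocessing.synchronize import Event as EventT, Semaphore as SemaphoreT

def child(sem: SemaphoreT, ev: EventT) -> None:
    sem.release()  # tell the parent this child is ready
    ev.wait()      # block until every sibling has checked in

if __name__ == "__main__":
    sem, ev = multiprocessing.Semaphore(0), multiprocessing.Event()
    procs = [multiprocessing.Process(target=child, args=(sem, ev)) for _ in range(4)]
    for p in procs:
        p.start()
    for _ in procs:
        sem.acquire()  # one release per child: all four are now waiting
    ev.set()           # release them simultaneously
    for p in procs:
        p.join()
```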
  {
    "path": "src/exo/shared/tests/test_resolve_reasoning_params.py",
    "content": "import pytest\n\nfrom exo.shared.types.text_generation import resolve_reasoning_params\n\n\ndef test_both_none_returns_none_none() -> None:\n    assert resolve_reasoning_params(None, None) == (None, None)\n\n\ndef test_both_set_passes_through_unchanged() -> None:\n    assert resolve_reasoning_params(\"high\", True) == (\"high\", True)\n    assert resolve_reasoning_params(\"none\", True) == (\"none\", True)\n    assert resolve_reasoning_params(\"low\", False) == (\"low\", False)\n\n\ndef test_enable_thinking_true_derives_medium() -> None:\n    assert resolve_reasoning_params(None, True) == (\"medium\", True)\n\n\ndef test_enable_thinking_false_derives_none() -> None:\n    assert resolve_reasoning_params(None, False) == (\"none\", False)\n\n\ndef test_reasoning_effort_none_derives_thinking_false() -> None:\n    assert resolve_reasoning_params(\"none\", None) == (\"none\", False)\n\n\n@pytest.mark.parametrize(\"effort\", [\"minimal\", \"low\", \"medium\", \"high\", \"xhigh\"])\ndef test_non_none_effort_derives_thinking_true(effort: str) -> None:\n    assert resolve_reasoning_params(effort, None) == (effort, True)  # pyright: ignore[reportArgumentType]\n"
  },
  {
    "path": "src/exo/shared/tests/test_state_serialization.py",
    "content": "from exo.shared.types.common import NodeId\nfrom exo.shared.types.multiaddr import Multiaddr\nfrom exo.shared.types.state import State\nfrom exo.shared.types.topology import Connection, SocketConnection\n\n\ndef test_state_serialization_roundtrip() -> None:\n    \"\"\"Verify that State → JSON → State round-trip preserves topology.\"\"\"\n\n    # --- build a simple state ------------------------------------------------\n    node_a = NodeId(\"node-a\")\n    node_b = NodeId(\"node-b\")\n\n    connection = Connection(\n        source=node_a,\n        sink=node_b,\n        edge=SocketConnection(\n            sink_multiaddr=Multiaddr(address=\"/ip4/127.0.0.1/tcp/10001\"),\n        ),\n    )\n\n    state = State()\n    state.topology.add_connection(connection)\n\n    json_repr = state.model_dump_json()\n    restored_state = State.model_validate_json(json_repr)\n\n    assert (\n        state.topology.to_snapshot().nodes\n        == restored_state.topology.to_snapshot().nodes\n    )\n    assert set(state.topology.to_snapshot().connections) == set(\n        restored_state.topology.to_snapshot().connections\n    )\n    assert restored_state.model_dump_json() == json_repr\n"
  },
  {
    "path": "src/exo/shared/tests/test_xdg_paths.py",
    "content": "\"\"\"Tests for XDG Base Directory Specification compliance.\"\"\"\n\nimport os\nimport sys\nfrom pathlib import Path\nfrom unittest import mock\n\n\ndef test_xdg_paths_on_linux():\n    \"\"\"Test that XDG paths are used on Linux when XDG env vars are set.\"\"\"\n    with (\n        mock.patch.dict(\n            os.environ,\n            {\n                \"XDG_CONFIG_HOME\": \"/tmp/test-config\",\n                \"XDG_DATA_HOME\": \"/tmp/test-data\",\n                \"XDG_CACHE_HOME\": \"/tmp/test-cache\",\n            },\n            clear=False,\n        ),\n        mock.patch.object(sys, \"platform\", \"linux\"),\n    ):\n        # Re-import to pick up mocked values\n        import importlib\n\n        import exo.shared.constants as constants\n\n        importlib.reload(constants)\n\n        assert Path(\"/tmp/test-config/exo\") == constants.EXO_CONFIG_HOME\n        assert Path(\"/tmp/test-data/exo\") == constants.EXO_DATA_HOME\n        assert Path(\"/tmp/test-cache/exo\") == constants.EXO_CACHE_HOME\n\n\ndef test_xdg_default_paths_on_linux():\n    \"\"\"Test that XDG default paths are used on Linux when env vars are not set.\"\"\"\n    # Remove XDG env vars and EXO_HOME\n    env = {\n        k: v\n        for k, v in os.environ.items()\n        if not k.startswith(\"XDG_\") and k != \"EXO_HOME\"\n    }\n    with (\n        mock.patch.dict(os.environ, env, clear=True),\n        mock.patch.object(sys, \"platform\", \"linux\"),\n    ):\n        import importlib\n\n        import exo.shared.constants as constants\n\n        importlib.reload(constants)\n\n        home = Path.home()\n        assert home / \".config\" / \"exo\" == constants.EXO_CONFIG_HOME\n        assert home / \".local/share\" / \"exo\" == constants.EXO_DATA_HOME\n        assert home / \".cache\" / \"exo\" == constants.EXO_CACHE_HOME\n\n\ndef test_legacy_exo_home_takes_precedence():\n    \"\"\"Test that EXO_HOME environment variable takes precedence for backward compatibility.\"\"\"\n    with mock.patch.dict(\n        os.environ,\n        {\n            \"EXO_HOME\": \".custom-exo\",\n            \"XDG_CONFIG_HOME\": \"/tmp/test-config\",\n        },\n        clear=False,\n    ):\n        import importlib\n\n        import exo.shared.constants as constants\n\n        importlib.reload(constants)\n\n        home = Path.home()\n        assert home / \".custom-exo\" == constants.EXO_CONFIG_HOME\n        assert home / \".custom-exo\" == constants.EXO_DATA_HOME\n\n\ndef test_macos_uses_traditional_paths():\n    \"\"\"Test that macOS uses traditional ~/.exo directory.\"\"\"\n    # Remove EXO_HOME to ensure we test the default behavior\n    env = {k: v for k, v in os.environ.items() if k != \"EXO_HOME\"}\n    with (\n        mock.patch.dict(os.environ, env, clear=True),\n        mock.patch.object(sys, \"platform\", \"darwin\"),\n    ):\n        import importlib\n\n        import exo.shared.constants as constants\n\n        importlib.reload(constants)\n\n        home = Path.home()\n        assert home / \".exo\" == constants.EXO_CONFIG_HOME\n        assert home / \".exo\" == constants.EXO_DATA_HOME\n        assert home / \".exo\" == constants.EXO_CACHE_HOME\n\n\ndef test_node_id_in_config_dir():\n    \"\"\"Test that node ID keypair is in the config directory.\"\"\"\n    import exo.shared.constants as constants\n\n    assert constants.EXO_NODE_ID_KEYPAIR.parent == constants.EXO_CONFIG_HOME\n\n\ndef test_models_in_data_dir():\n    \"\"\"Test that models directory is in the data directory.\"\"\"\n    # Clear 
EXO_MODELS_DIR to test default behavior\n    env = {k: v for k, v in os.environ.items() if k != \"EXO_MODELS_DIR\"}\n    with mock.patch.dict(os.environ, env, clear=True):\n        import importlib\n\n        import exo.shared.constants as constants\n\n        importlib.reload(constants)\n\n        assert constants.EXO_MODELS_DIR.parent == constants.EXO_DATA_HOME\n"
  },
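The reload dance in these tests exists because module-level constants are evaluated once at import time; patching `os.environ` alone changes nothing. A minimal sketch (assumes a Linux platform so the XDG branch applies, and that `EXO_HOME` is unset, since it takes precedence):

```python
import importlib
import os
from unittest import mock

import exo.shared.constants as constants

with mock.patch.dict(os.environ, {"XDG_CONFIG_HOME": "/tmp/cfg"}, clear=False):
    importlib.reload(constants)       # re-evaluates the paths under the patched env
    print(constants.EXO_CONFIG_HOME)  # expected: /tmp/cfg/exo
importlib.reload(constants)           # recompute from the real environment afterwards
```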
  {
    "path": "src/exo/shared/topology.py",
    "content": "import contextlib\nfrom collections.abc import Mapping, Sequence\nfrom dataclasses import dataclass, field\nfrom typing import Iterable\n\nimport rustworkx as rx\nfrom pydantic import BaseModel, ConfigDict\n\nfrom exo.shared.types.common import NodeId\nfrom exo.shared.types.profiling import (\n    InterfaceType,\n    NodeNetworkInfo,\n    ThunderboltBridgeStatus,\n)\nfrom exo.shared.types.topology import (\n    Connection,\n    Cycle,\n    RDMAConnection,\n    SocketConnection,\n)\n\n\nclass TopologySnapshot(BaseModel):\n    nodes: Sequence[NodeId]\n    connections: Mapping[\n        NodeId, Mapping[NodeId, Sequence[SocketConnection | RDMAConnection]]\n    ]\n\n    model_config = ConfigDict(frozen=True, extra=\"forbid\")\n\n\n@dataclass\nclass Topology:\n    _graph: rx.PyDiGraph[NodeId, SocketConnection | RDMAConnection] = field(\n        init=False, default_factory=rx.PyDiGraph\n    )\n    _vertex_indices: dict[NodeId, int] = field(init=False, default_factory=dict)\n\n    def to_snapshot(self) -> TopologySnapshot:\n        return TopologySnapshot(\n            nodes=list(self.list_nodes()), connections=self.map_connections()\n        )\n\n    @classmethod\n    def from_snapshot(cls, snapshot: TopologySnapshot) -> \"Topology\":\n        topology = cls()\n\n        for node_id in snapshot.nodes:\n            with contextlib.suppress(ValueError):\n                topology.add_node(node_id)\n\n        for source in snapshot.connections:\n            for sink in snapshot.connections[source]:\n                for edge in snapshot.connections[source][sink]:\n                    topology.add_connection(\n                        Connection(source=source, sink=sink, edge=edge)\n                    )\n\n        return topology\n\n    def add_node(self, node_id: NodeId) -> None:\n        if node_id in self._vertex_indices:\n            return\n        rx_id = self._graph.add_node(node_id)\n        self._vertex_indices[node_id] = rx_id\n\n    def node_is_leaf(self, node_id: NodeId) -> bool:\n        return (\n            node_id in self._vertex_indices\n            and len(self._graph.neighbors(self._vertex_indices[node_id])) <= 1\n        )\n\n    def neighbours(self, node_id: NodeId) -> list[NodeId]:\n        return [\n            self._graph[rx_id]\n            for rx_id in self._graph.neighbors(self._vertex_indices[node_id])\n        ]\n\n    def out_edges(self, node_id: NodeId) -> Iterable[Connection]:\n        if node_id not in self._vertex_indices:\n            return []\n        return (\n            Connection(source=self._graph[source], sink=self._graph[sink], edge=edge)\n            for source, sink, edge in self._graph.out_edges(\n                self._vertex_indices[node_id]\n            )\n        )\n\n    def contains_node(self, node_id: NodeId) -> bool:\n        return node_id in self._vertex_indices\n\n    def add_connection(self, conn: Connection) -> None:\n        source, sink, edge = conn.source, conn.sink, conn.edge\n        del conn\n        if edge in self.get_all_connections_between(source, sink):\n            return\n\n        if source not in self._vertex_indices:\n            self.add_node(source)\n        if sink not in self._vertex_indices:\n            self.add_node(sink)\n\n        src_id = self._vertex_indices[source]\n        sink_id = self._vertex_indices[sink]\n\n        _ = self._graph.add_edge(src_id, sink_id, edge)\n\n    def get_all_connections_between(\n        self, source: NodeId, sink: NodeId\n    ) -> Iterable[SocketConnection | 
RDMAConnection]:\n        if source not in self._vertex_indices:\n            return []\n        if sink not in self._vertex_indices:\n            return []\n\n        src_id = self._vertex_indices[source]\n        sink_id = self._vertex_indices[sink]\n        try:\n            return self._graph.get_all_edge_data(src_id, sink_id)\n        except rx.NoEdgeBetweenNodes:\n            return []\n\n    def list_nodes(self) -> Iterable[NodeId]:\n        return self._graph.nodes()\n\n    def map_connections(\n        self,\n    ) -> Mapping[NodeId, Mapping[NodeId, Sequence[SocketConnection | RDMAConnection]]]:\n        base: dict[NodeId, dict[NodeId, list[SocketConnection | RDMAConnection]]] = {}\n        for src_id, sink_id, connection in self._graph.weighted_edge_list():\n            source = self._graph[src_id]\n            sink = self._graph[sink_id]\n            if source not in base:\n                base[source] = {}\n            if sink not in base[source]:\n                base[source][sink] = []\n            base[source][sink].append(connection)\n        return base\n\n    def list_connections(\n        self,\n    ) -> Iterable[Connection]:\n        return (\n            (\n                Connection(\n                    source=self._graph[src_id],\n                    sink=self._graph[sink_id],\n                    edge=connection,\n                )\n            )\n            for src_id, sink_id, connection in self._graph.weighted_edge_list()\n        )\n\n    def remove_node(self, node_id: NodeId) -> None:\n        if node_id not in self._vertex_indices:\n            return\n\n        rx_idx = self._vertex_indices[node_id]\n        self._graph.remove_node(rx_idx)\n\n        del self._vertex_indices[node_id]\n\n    def replace_all_out_rdma_connections(\n        self, source: NodeId, new_connections: Sequence[Connection]\n    ) -> None:\n        for conn_idx in self._graph.out_edge_indices(self._vertex_indices[source]):\n            if isinstance(self._graph.get_edge_data_by_index(conn_idx), RDMAConnection):\n                self._graph.remove_edge_from_index(conn_idx)\n        for conn in new_connections:\n            self.add_connection(conn)\n\n    def remove_connection(self, conn: Connection) -> None:\n        if (\n            conn.source not in self._vertex_indices\n            or conn.sink not in self._vertex_indices\n        ):\n            return\n        for conn_idx in self._graph.edge_indices_from_endpoints(\n            self._vertex_indices[conn.source], self._vertex_indices[conn.sink]\n        ):\n            if self._graph.get_edge_data_by_index(conn_idx) == conn.edge:\n                self._graph.remove_edge_from_index(conn_idx)\n\n    def get_cycles(self) -> list[Cycle]:\n        \"\"\"Get simple cycles in the graph, including singleton cycles\"\"\"\n\n        cycle_idxs = rx.simple_cycles(self._graph)\n        cycles: list[Cycle] = []\n        for cycle_idx in cycle_idxs:\n            cycle = Cycle(node_ids=[self._graph[idx] for idx in cycle_idx])\n            cycles.append(cycle)\n        for node_id in self.list_nodes():\n            cycles.append(Cycle(node_ids=[node_id]))\n        return cycles\n\n    def get_rdma_cycles(self) -> list[Cycle]:\n        rdma_edges = [\n            (u, v, conn)\n            for u, v, conn in self._graph.weighted_edge_list()\n            if isinstance(conn, RDMAConnection)\n        ]\n\n        rdma_graph: rx.PyDiGraph[NodeId, SocketConnection | RDMAConnection] = (\n            rx.PyDiGraph()\n        )\n        
index_map = dict(\n            zip(\n                self._graph.node_indices(),\n                rdma_graph.add_nodes_from(self._graph.nodes()),\n            )\n        )\n\n        # The fresh subgraph numbers its nodes from 0, while the original index\n        # space may contain holes after remove_node, so edge endpoints must be\n        # re-mapped rather than reused verbatim.\n        for u, v, conn in rdma_edges:\n            rdma_graph.add_edge(index_map[u], index_map[v], conn)\n\n        cycle_idxs = rx.simple_cycles(rdma_graph)\n        cycles: list[Cycle] = []\n        for cycle_idx in cycle_idxs:\n            cycle = Cycle(node_ids=[rdma_graph[idx] for idx in cycle_idx])\n            cycles.append(cycle)\n\n        return cycles\n\n    def get_subgraph_from_nodes(self, node_ids: list[NodeId]) -> \"Topology\":\n        topology = Topology()\n        for node_id in node_ids:\n            topology.add_node(node_id)\n        for connection in self.list_connections():\n            if connection.source in node_ids and connection.sink in node_ids:\n                topology.add_connection(connection)\n        return topology\n\n    def is_rdma_cycle(self, cycle: Cycle) -> bool:\n        node_idxs = [node for node in cycle]\n        rx_idxs = [self._vertex_indices[idx] for idx in node_idxs]\n        for rid in rx_idxs:\n            for neighbor_rid in self._graph.neighbors(rid):\n                if neighbor_rid not in rx_idxs:\n                    continue\n                has_rdma = False\n                for edge in self._graph.get_all_edge_data(rid, neighbor_rid):\n                    if isinstance(edge, RDMAConnection):\n                        has_rdma = True\n                        break\n                if not has_rdma:\n                    return False\n        return True\n\n    def get_thunderbolt_bridge_cycles(\n        self,\n        node_tb_bridge_status: Mapping[NodeId, ThunderboltBridgeStatus],\n        node_network: Mapping[NodeId, NodeNetworkInfo],\n    ) -> list[list[NodeId]]:\n        \"\"\"\n        Find cycles in the Thunderbolt topology where all nodes have TB bridge enabled.\n        Only returns cycles with >=2 nodes (2+ machines in a loop), as\n        1 node doesn't cause the broadcast storm problem.\n        \"\"\"\n        enabled_nodes = {\n            node_id\n            for node_id, status in node_tb_bridge_status.items()\n            if status.enabled\n        }\n\n        if len(enabled_nodes) < 2:\n            return []\n\n        thunderbolt_ips = _get_ips_with_interface_type(\n            enabled_nodes, node_network, \"thunderbolt\"\n        )\n\n        # Build subgraph with only TB bridge enabled nodes and thunderbolt connections\n        graph: rx.PyDiGraph[NodeId, SocketConnection | RDMAConnection] = rx.PyDiGraph()\n        node_to_idx: dict[NodeId, int] = {}\n\n        for node_id in enabled_nodes:\n            if node_id in self._vertex_indices:\n                node_to_idx[node_id] = graph.add_node(node_id)\n\n        for u, v, conn in self._graph.weighted_edge_list():\n            source_id, sink_id = self._graph[u], self._graph[v]\n            if source_id not in node_to_idx or sink_id not in node_to_idx:\n                continue\n            # Include connection if it's over a thunderbolt interface\n            if (\n                isinstance(conn, SocketConnection)\n                and conn.sink_multiaddr.ip_address in thunderbolt_ips\n            ):\n                graph.add_edge(node_to_idx[source_id], node_to_idx[sink_id], conn)\n            if isinstance(conn, RDMAConnection):\n                graph.add_edge(node_to_idx[source_id], node_to_idx[sink_id], conn)\n\n        return [\n            [graph[idx] for idx in cycle]\n            for cycle in rx.simple_cycles(graph)\n            if len(cycle) >= 2\n        ]\n\n\ndef _get_ips_with_interface_type(\n    node_ids: set[NodeId],\n    
node_network: Mapping[NodeId, NodeNetworkInfo],\n    interface_type: InterfaceType,\n) -> set[str]:\n    \"\"\"Get all IP addresses on interfaces of the specified type for the given nodes.\"\"\"\n    ips: set[str] = set()\n    for node_id in node_ids:\n        network_info = node_network.get(node_id, NodeNetworkInfo())\n        for iface in network_info.interfaces:\n            if iface.interface_type == interface_type:\n                ips.add(iface.ip_address)\n    return ips\n"
  },
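A usage sketch for `Topology`, mirroring the shapes already used in `test_state_serialization`: endpoints are added implicitly by connections, and snapshots are the JSON-safe interchange form.

```python
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.topology import Connection, SocketConnection

topology = Topology()
edge = SocketConnection(sink_multiaddr=Multiaddr(address="/ip4/10.0.0.2/tcp/10001"))
topology.add_connection(Connection(source=NodeId("a"), sink=NodeId("b"), edge=edge))

assert topology.contains_node(NodeId("a"))  # endpoints were added implicitly
assert list(topology.get_all_connections_between(NodeId("a"), NodeId("b"))) == [edge]

# Round-tripping through a snapshot preserves the graph.
restored = Topology.from_snapshot(topology.to_snapshot())
assert list(restored.list_nodes()) == list(topology.list_nodes())
```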
  {
    "path": "src/exo/shared/tracing.py",
    "content": "from __future__ import annotations\n\nimport json\nimport time\nfrom collections import defaultdict\nfrom collections.abc import Generator\nfrom contextlib import contextmanager\nfrom contextvars import ContextVar\nfrom dataclasses import dataclass, field\nfrom pathlib import Path\nfrom typing import cast, final\n\nfrom exo.shared.constants import EXO_TRACING_ENABLED\nfrom exo.worker.runner.bootstrap import logger\n\n# Context variable to track the current trace category for hierarchical nesting\n_current_category: ContextVar[str | None] = ContextVar(\"current_category\", default=None)\n\n\n@final\n@dataclass(frozen=True)\nclass TraceEvent:\n    name: str\n    start_us: int\n    duration_us: int\n    rank: int\n    category: str\n\n\n@final\n@dataclass\nclass CategoryStats:\n    total_us: int = 0\n    count: int = 0\n    min_us: int = 0\n    max_us: int = 0\n\n    def add(self, duration_us: int) -> None:\n        if self.count == 0:\n            self.min_us = duration_us\n            self.max_us = duration_us\n        else:\n            self.min_us = min(self.min_us, duration_us)\n            self.max_us = max(self.max_us, duration_us)\n        self.total_us += duration_us\n        self.count += 1\n\n    @property\n    def avg_us(self) -> float:\n        return self.total_us / self.count if self.count > 0 else 0.0\n\n\n@final\n@dataclass\nclass TraceStats:\n    total_wall_time_us: int = 0\n    by_category: dict[str, CategoryStats] = field(default_factory=dict)\n    by_rank: dict[int, dict[str, CategoryStats]] = field(default_factory=dict)\n\n\n# Global trace buffer - each rank accumulates traces here\n_trace_buffer: list[TraceEvent] = []\n\n\ndef _record_span(\n    name: str, start_us: int, duration_us: int, rank: int, category: str\n) -> None:\n    _trace_buffer.append(\n        TraceEvent(\n            name=name,\n            start_us=start_us,\n            duration_us=duration_us,\n            rank=rank,\n            category=category,\n        )\n    )\n\n\n@contextmanager\ndef trace(\n    name: str,\n    rank: int,\n    category: str = \"compute\",\n) -> Generator[None, None, None]:\n    \"\"\"Context manager to trace any operation.\n\n    Nested traces automatically inherit the parent category, creating hierarchical\n    categories like \"sync/compute\" or \"async/comms\".\n\n    Args:\n        name: Name of the operation (e.g., \"recv 0\", \"send 1\", \"joint_blocks\")\n        rank: This rank's ID\n        category: Category for grouping in trace viewer (\"comm\", \"compute\", \"step\")\n\n    Example:\n        with trace(f\"sync {t}\", rank, \"sync\"):\n            with trace(\"joint_blocks\", rank, \"compute\"):\n                # Recorded with category \"sync/compute\"\n                hidden_states = some_computation(...)\n    \"\"\"\n    if not EXO_TRACING_ENABLED:\n        yield\n        return\n\n    # Combine with parent category if nested\n    parent = _current_category.get()\n    full_category = f\"{parent}/{category}\" if parent else category\n\n    # Set as current for nested traces\n    token = _current_category.set(full_category)\n\n    try:\n        start_us = int(time.time() * 1_000_000)\n        start_perf = time.perf_counter()\n        yield\n        duration_us = int((time.perf_counter() - start_perf) * 1_000_000)\n        _record_span(name, start_us, duration_us, rank, full_category)\n    finally:\n        _current_category.reset(token)\n\n\ndef get_trace_buffer() -> list[TraceEvent]:\n    return list(_trace_buffer)\n\n\ndef clear_trace_buffer() 
-> None:\n    _trace_buffer.clear()\n\n\ndef export_trace(traces: list[TraceEvent], output_path: Path) -> None:\n    trace_events: list[dict[str, object]] = []\n\n    for event in traces:\n        # Chrome trace format uses \"X\" for complete events (with duration)\n        chrome_event: dict[str, object] = {\n            \"name\": event.name,\n            \"cat\": event.category,\n            \"ph\": \"X\",\n            \"ts\": event.start_us,\n            \"dur\": event.duration_us,\n            \"pid\": 0,\n            \"tid\": event.rank,\n            \"args\": {\"rank\": event.rank},\n        }\n        trace_events.append(chrome_event)\n\n    ranks_seen = set(t.rank for t in traces)\n    for rank in ranks_seen:\n        trace_events.append(\n            {\n                \"name\": \"thread_name\",\n                \"ph\": \"M\",  # Metadata event\n                \"pid\": 0,\n                \"tid\": rank,\n                \"args\": {\"name\": f\"Rank {rank}\"},\n            }\n        )\n\n    chrome_trace = {\"traceEvents\": trace_events}\n\n    try:\n        output_path.parent.mkdir(parents=True, exist_ok=True)\n        with open(output_path, \"w\") as f:\n            json.dump(chrome_trace, f, indent=2)\n    except OSError as e:\n        logger.warning(f\"Failed to export trace to {output_path}: {e}\")\n\n\ndef load_trace_file(path: Path) -> list[TraceEvent]:\n    with open(path) as f:\n        data = cast(dict[str, list[dict[str, object]]], json.load(f))\n\n    events = data.get(\"traceEvents\", [])\n    traces: list[TraceEvent] = []\n\n    for event in events:\n        # Skip metadata events\n        if event.get(\"ph\") == \"M\":\n            continue\n\n        name = str(event.get(\"name\", \"\"))\n        category = str(event.get(\"cat\", \"\"))\n        ts_value = event.get(\"ts\", 0)\n        dur_value = event.get(\"dur\", 0)\n        tid_value = event.get(\"tid\", 0)\n        start_us = int(ts_value) if isinstance(ts_value, (int, float, str)) else 0\n        duration_us = int(dur_value) if isinstance(dur_value, (int, float, str)) else 0\n\n        # Get rank from tid or args\n        rank = int(tid_value) if isinstance(tid_value, (int, float, str)) else 0\n        args = event.get(\"args\")\n        if isinstance(args, dict):\n            args_dict = cast(dict[str, object], args)\n            rank_from_args = args_dict.get(\"rank\")\n            if isinstance(rank_from_args, (int, float, str)):\n                rank = int(rank_from_args)\n\n        traces.append(\n            TraceEvent(\n                name=name,\n                start_us=start_us,\n                duration_us=duration_us,\n                rank=rank,\n                category=category,\n            )\n        )\n\n    return traces\n\n\ndef compute_stats(traces: list[TraceEvent]) -> TraceStats:\n    stats = TraceStats()\n\n    if not traces:\n        return stats\n\n    # Calculate wall time from earliest start to latest end\n    min_start = min(t.start_us for t in traces)\n    max_end = max(t.start_us + t.duration_us for t in traces)\n    stats.total_wall_time_us = max_end - min_start\n\n    # Initialize nested dicts\n    by_category: dict[str, CategoryStats] = defaultdict(CategoryStats)\n    by_rank: dict[int, dict[str, CategoryStats]] = defaultdict(\n        lambda: defaultdict(CategoryStats)\n    )\n\n    for event in traces:\n        # By category\n        by_category[event.category].add(event.duration_us)\n\n        # By rank and category\n        
by_rank[event.rank][event.category].add(event.duration_us)\n\n    stats.by_category = dict(by_category)\n    stats.by_rank = {k: dict(v) for k, v in by_rank.items()}\n\n    return stats\n"
  },
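A usage sketch for the tracing helpers; note that `trace` records nothing unless `EXO_TRACING_ENABLED` is set, so the buffer below stays empty when tracing is off.

```python
from pathlib import Path

from exo.shared.tracing import compute_stats, export_trace, get_trace_buffer, trace

with trace("sync 0", rank=0, category="sync"):
    with trace("joint_blocks", rank=0, category="compute"):
        pass  # recorded under the hierarchical category "sync/compute"

events = get_trace_buffer()
stats = compute_stats(events)
print(stats.total_wall_time_us, sorted(stats.by_category))

# Chrome trace format: load the file in chrome://tracing or Perfetto.
export_trace(events, Path("/tmp/exo-trace.json"))
```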
  {
    "path": "src/exo/shared/types/__init__.py",
    "content": ""
  },
  {
    "path": "src/exo/shared/types/chunks.py",
    "content": "from collections.abc import Generator\nfrom typing import Any, Literal\n\nfrom exo.api.types import (\n    FinishReason,\n    GenerationStats,\n    ImageGenerationStats,\n    ToolCallItem,\n    TopLogprobItem,\n    Usage,\n)\nfrom exo.shared.models.model_cards import ModelId\nfrom exo.utils.pydantic_ext import TaggedModel\n\nfrom .common import CommandId\n\n\nclass BaseChunk(TaggedModel):\n    model: ModelId\n\n\nclass TokenChunk(BaseChunk):\n    text: str\n    token_id: int\n    usage: Usage | None\n    finish_reason: Literal[\"stop\", \"length\", \"content_filter\"] | None = None\n    stats: GenerationStats | None = None\n    logprob: float | None = None\n    top_logprobs: list[TopLogprobItem] | None = None\n    is_thinking: bool = False\n\n\nclass ErrorChunk(BaseChunk):\n    error_message: str\n    finish_reason: Literal[\"error\"] = \"error\"\n\n\nclass ToolCallChunk(BaseChunk):\n    tool_calls: list[ToolCallItem]\n    usage: Usage | None\n    finish_reason: Literal[\"tool_calls\"] = \"tool_calls\"\n    stats: GenerationStats | None = None\n\n\nclass ImageChunk(BaseChunk):\n    data: str\n    chunk_index: int\n    total_chunks: int\n    image_index: int\n    is_partial: bool = False\n    partial_index: int | None = None\n    total_partials: int | None = None\n    stats: ImageGenerationStats | None = None\n    format: Literal[\"png\", \"jpeg\", \"webp\"] | None = None\n    finish_reason: FinishReason | None = None\n    error_message: str | None = None\n\n    def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:\n        for name, value in super().__repr_args__():  # pyright: ignore[reportAny]\n            if name == \"data\" and hasattr(value, \"__len__\"):  # pyright: ignore[reportAny]\n                yield name, f\"<{len(self.data)} chars>\"\n            elif name is not None:\n                yield name, value\n\n\nclass InputImageChunk(BaseChunk):\n    command_id: CommandId\n    data: str\n    chunk_index: int\n    total_chunks: int\n\n    def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:\n        for name, value in super().__repr_args__():  # pyright: ignore[reportAny]\n            if name == \"data\" and hasattr(value, \"__len__\"):  # pyright: ignore[reportAny]\n                yield name, f\"<{len(self.data)} chars>\"\n            elif name is not None:\n                yield name, value\n\n\nclass PrefillProgressChunk(BaseChunk):\n    \"\"\"Data class for prefill progress events during streaming.\"\"\"\n\n    processed_tokens: int\n    total_tokens: int\n\n\nGenerationChunk = (\n    TokenChunk | ImageChunk | ToolCallChunk | ErrorChunk | PrefillProgressChunk\n)\n"
  },
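The `__repr_args__` override above keeps multi-kilobyte base64 payloads out of logs and debuggers. The same trick in isolation, with plain Pydantic and no exo imports:

```python
from collections.abc import Generator
from typing import Any

from pydantic import BaseModel

class Blob(BaseModel):
    name: str
    data: str

    def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
        for field_name, value in super().__repr_args__():
            if field_name == "data":
                yield field_name, f"<{len(self.data)} chars>"
            else:
                yield field_name, value

print(repr(Blob(name="logo", data="A" * 4096)))  # Blob(name='logo', data='<4096 chars>')
```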
  {
    "path": "src/exo/shared/types/commands.py",
    "content": "from pydantic import Field\n\nfrom exo.api.types import (\n    ImageEditsTaskParams,\n    ImageGenerationTaskParams,\n)\nfrom exo.shared.models.model_cards import ModelCard, ModelId\nfrom exo.shared.types.chunks import InputImageChunk\nfrom exo.shared.types.common import CommandId, NodeId, SystemId\nfrom exo.shared.types.text_generation import TextGenerationTaskParams\nfrom exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta\nfrom exo.shared.types.worker.shards import Sharding, ShardMetadata\nfrom exo.utils.pydantic_ext import CamelCaseModel, TaggedModel\n\n\nclass BaseCommand(TaggedModel):\n    command_id: CommandId = Field(default_factory=CommandId)\n\n\nclass TestCommand(BaseCommand):\n    __test__ = False\n\n\nclass TextGeneration(BaseCommand):\n    task_params: TextGenerationTaskParams\n\n\nclass ImageGeneration(BaseCommand):\n    task_params: ImageGenerationTaskParams\n\n\nclass ImageEdits(BaseCommand):\n    task_params: ImageEditsTaskParams\n\n\nclass PlaceInstance(BaseCommand):\n    model_card: ModelCard\n    sharding: Sharding\n    instance_meta: InstanceMeta\n    min_nodes: int\n\n\nclass CreateInstance(BaseCommand):\n    instance: Instance\n\n\nclass DeleteInstance(BaseCommand):\n    instance_id: InstanceId\n\n\nclass TaskCancelled(BaseCommand):\n    cancelled_command_id: CommandId\n\n\nclass TaskFinished(BaseCommand):\n    finished_command_id: CommandId\n\n\nclass SendInputChunk(BaseCommand):\n    \"\"\"Command to send an input image chunk (converted to event by master).\"\"\"\n\n    chunk: InputImageChunk\n\n\nclass RequestEventLog(BaseCommand):\n    since_idx: int\n\n\nclass StartDownload(BaseCommand):\n    target_node_id: NodeId\n    shard_metadata: ShardMetadata\n\n\nclass DeleteDownload(BaseCommand):\n    target_node_id: NodeId\n    model_id: ModelId\n\n\nclass CancelDownload(BaseCommand):\n    target_node_id: NodeId\n    model_id: ModelId\n\n\nDownloadCommand = StartDownload | DeleteDownload | CancelDownload\n\n\nCommand = (\n    TestCommand\n    | RequestEventLog\n    | TextGeneration\n    | ImageGeneration\n    | ImageEdits\n    | PlaceInstance\n    | CreateInstance\n    | DeleteInstance\n    | TaskCancelled\n    | TaskFinished\n    | SendInputChunk\n)\n\n\nclass ForwarderCommand(CamelCaseModel):\n    origin: SystemId\n    command: Command\n\n\nclass ForwarderDownloadCommand(CamelCaseModel):\n    origin: SystemId\n    command: DownloadCommand\n"
  },
  {
    "path": "src/exo/shared/types/common.py",
    "content": "from typing import Self\nfrom uuid import uuid4\n\nfrom pydantic import GetCoreSchemaHandler, field_validator\nfrom pydantic_core import core_schema\n\nfrom exo.utils.pydantic_ext import CamelCaseModel\n\n\nclass Id(str):\n    def __new__(cls, value: str | None = None) -> Self:\n        return super().__new__(cls, value or str(uuid4()))\n\n    @classmethod\n    def __get_pydantic_core_schema__(\n        cls, _source: type, handler: GetCoreSchemaHandler\n    ) -> core_schema.CoreSchema:\n        # Just use a plain string schema\n        return core_schema.no_info_after_validator_function(\n            cls, core_schema.str_schema()\n        )\n\n\nclass NodeId(Id):\n    pass\n\n\nclass SystemId(Id):\n    pass\n\n\nclass ModelId(Id):\n    def normalize(self) -> str:\n        return self.replace(\"/\", \"--\")\n\n    def short(self) -> str:\n        return self.split(\"/\")[-1]\n\n\nclass SessionId(CamelCaseModel):\n    master_node_id: NodeId\n    election_clock: int\n\n\nclass CommandId(Id):\n    pass\n\n\nclass Host(CamelCaseModel):\n    ip: str\n    port: int\n\n    def __str__(self) -> str:\n        return f\"{self.ip}:{self.port}\"\n\n    @field_validator(\"port\")\n    @classmethod\n    def check_port(cls, v: int) -> int:\n        if not (0 <= v <= 65535):\n            raise ValueError(\"Port must be between 0 and 65535\")\n        return v\n"
  },
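A behavior sketch for the `Id` string subclasses: constructing one with no argument mints a fresh UUID4 string, and the custom core schema lets plain strings validate into the subclass inside Pydantic models.

```python
from pydantic import BaseModel

from exo.shared.types.common import NodeId

assert len(NodeId()) == 36           # no argument: a fresh uuid4 string
assert NodeId("node-a") == "node-a"  # with an argument: wraps the given value

class Record(BaseModel):
    node: NodeId

assert isinstance(Record(node="node-a").node, NodeId)  # str coerced via the schema
```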
  {
    "path": "src/exo/shared/types/events.py",
    "content": "from datetime import datetime\nfrom typing import final\n\nfrom pydantic import Field\n\nfrom exo.shared.topology import Connection\nfrom exo.shared.types.chunks import GenerationChunk, InputImageChunk\nfrom exo.shared.types.common import CommandId, Id, NodeId, SessionId, SystemId\nfrom exo.shared.types.tasks import Task, TaskId, TaskStatus\nfrom exo.shared.types.worker.downloads import DownloadProgress\nfrom exo.shared.types.worker.instances import Instance, InstanceId\nfrom exo.shared.types.worker.runners import RunnerId, RunnerStatus\nfrom exo.utils.info_gatherer.info_gatherer import GatheredInfo\nfrom exo.utils.pydantic_ext import CamelCaseModel, FrozenModel, TaggedModel\n\n\nclass EventId(Id):\n    \"\"\"\n    Newtype around `ID`\n    \"\"\"\n\n\nclass BaseEvent(TaggedModel):\n    event_id: EventId = Field(default_factory=EventId)\n    # Internal, for debugging. Please don't rely on this field for anything!\n    _master_time_stamp: None | datetime = None\n\n\nclass TestEvent(BaseEvent):\n    __test__ = False\n\n\nclass TaskCreated(BaseEvent):\n    task_id: TaskId\n    task: Task\n\n\nclass TaskAcknowledged(BaseEvent):\n    task_id: TaskId\n\n\nclass TaskDeleted(BaseEvent):\n    task_id: TaskId\n\n\nclass TaskStatusUpdated(BaseEvent):\n    task_id: TaskId\n    task_status: TaskStatus\n\n\nclass TaskFailed(BaseEvent):\n    task_id: TaskId\n    error_type: str\n    error_message: str\n\n\nclass InstanceCreated(BaseEvent):\n    instance: Instance\n\n    def __eq__(self, other: object) -> bool:\n        if isinstance(other, InstanceCreated):\n            return self.instance == other.instance and self.event_id == other.event_id\n\n        return False\n\n\nclass InstanceDeleted(BaseEvent):\n    instance_id: InstanceId\n\n\nclass RunnerStatusUpdated(BaseEvent):\n    runner_id: RunnerId\n    runner_status: RunnerStatus\n\n\nclass NodeTimedOut(BaseEvent):\n    node_id: NodeId\n\n\n# TODO: bikeshed this name\nclass NodeGatheredInfo(BaseEvent):\n    node_id: NodeId\n    when: str  # this is a manually cast datetime overrode by the master when the event is indexed, rather than the local time on the device\n    info: GatheredInfo\n\n\nclass NodeDownloadProgress(BaseEvent):\n    download_progress: DownloadProgress\n\n\nclass ChunkGenerated(BaseEvent):\n    command_id: CommandId\n    chunk: GenerationChunk\n\n\nclass InputChunkReceived(BaseEvent):\n    command_id: CommandId\n    chunk: InputImageChunk\n\n\nclass TopologyEdgeCreated(BaseEvent):\n    conn: Connection\n\n\nclass TopologyEdgeDeleted(BaseEvent):\n    conn: Connection\n\n\n@final\nclass TraceEventData(FrozenModel):\n    name: str\n    start_us: int\n    duration_us: int\n    rank: int\n    category: str\n\n\n@final\nclass TracesCollected(BaseEvent):\n    task_id: TaskId\n    rank: int\n    traces: list[TraceEventData]\n\n\n@final\nclass TracesMerged(BaseEvent):\n    task_id: TaskId\n    traces: list[TraceEventData]\n\n\nEvent = (\n    TestEvent\n    | TaskCreated\n    | TaskStatusUpdated\n    | TaskFailed\n    | TaskDeleted\n    | TaskAcknowledged\n    | InstanceCreated\n    | InstanceDeleted\n    | RunnerStatusUpdated\n    | NodeTimedOut\n    | NodeGatheredInfo\n    | NodeDownloadProgress\n    | ChunkGenerated\n    | InputChunkReceived\n    | TopologyEdgeCreated\n    | TopologyEdgeDeleted\n    | TracesCollected\n    | TracesMerged\n)\n\n\nclass IndexedEvent(CamelCaseModel):\n    \"\"\"An event indexed by the master, with a globally unique index\"\"\"\n\n    idx: int = Field(ge=0)\n    event: Event\n\n\nclass 
GlobalForwarderEvent(CamelCaseModel):\n    \"\"\"An event the forwarder will serialize and send over the network\"\"\"\n\n    origin_idx: int = Field(ge=0)\n    origin: NodeId\n    session: SessionId\n    event: Event\n\n\nclass LocalForwarderEvent(CamelCaseModel):\n    \"\"\"An event the forwarder will serialize and send over the network\"\"\"\n\n    origin_idx: int = Field(ge=0)\n    origin: SystemId\n    session: SessionId\n    event: Event\n"
  },
  {
    "path": "src/exo/shared/types/memory.py",
    "content": "from math import ceil\nfrom typing import Self, overload\n\nfrom exo.utils.pydantic_ext import FrozenModel\n\n\nclass Memory(FrozenModel):\n    in_bytes: int = 0\n\n    @classmethod\n    def from_bytes(cls, val: int) -> Self:\n        \"\"\"Construct a new Memory object from a number of bytes\"\"\"\n        return cls(in_bytes=val)\n\n    @property\n    def in_kb(self) -> int:\n        \"\"\"The approximate kilobytes this memory represents, rounded up. Setting this property rounds to the nearest byte.\"\"\"\n        return ceil(self.in_bytes / 1024)\n\n    @in_kb.setter\n    def in_kb(self, val: int):\n        \"\"\"Set this memorys value in kilobytes.\"\"\"\n        self.in_bytes = val * 1024\n\n    @classmethod\n    def from_kb(cls, val: int) -> Self:\n        \"\"\"Construct a new Memory object from a number of kilobytes\"\"\"\n        return cls(in_bytes=val * 1024)\n\n    @classmethod\n    def from_float_kb(cls, val: float) -> Self:\n        \"\"\"Construct a new Memory object from a number of kilobytes, rounding where appropriate\"\"\"\n        return cls(in_bytes=round(val * 1024))\n\n    @property\n    def in_mb(self) -> int:\n        \"\"\"The approximate megabytes this memory represents, rounded to nearest MB. Setting this property rounds to the nearest byte.\"\"\"\n        return round(self.in_bytes / (1024**2))\n\n    @in_mb.setter\n    def in_mb(self, val: int):\n        \"\"\"Set the megabytes for this memory.\"\"\"\n        self.in_bytes = val * (1024**2)\n\n    @property\n    def in_float_mb(self) -> float:\n        \"\"\"The megabytes this memory represents as a float. Setting this property rounds to the nearest byte.\"\"\"\n        return self.in_bytes / (1024**2)\n\n    @in_float_mb.setter\n    def in_float_mb(self, val: float):\n        \"\"\"Set the megabytes for this memory, rounded to the nearest byte.\"\"\"\n        self.in_bytes = round(val * (1024**2))\n\n    @classmethod\n    def from_mb(cls, val: float) -> Self:\n        \"\"\"Construct a new Memory object from a number of megabytes\"\"\"\n        return cls(in_bytes=round(val * (1024**2)))\n\n    @classmethod\n    def from_gb(cls, val: float) -> Self:\n        \"\"\"Construct a new Memory object from a number of megabytes\"\"\"\n        return cls(in_bytes=round(val * (1024**3)))\n\n    @property\n    def in_gb(self) -> float:\n        \"\"\"The approximate gigabytes this memory represents.\"\"\"\n        return self.in_bytes / (1024**3)\n\n    def __add__(self, other: object) -> \"Memory\":\n        if isinstance(other, Memory):\n            return Memory.from_bytes(self.in_bytes + other.in_bytes)\n        return NotImplemented\n\n    def __radd__(self, other: object) -> \"Memory\":\n        if other == 0:\n            return self\n        return NotImplemented\n\n    def __sub__(self, other: object) -> \"Memory\":\n        if isinstance(other, Memory):\n            return Memory.from_bytes(self.in_bytes - other.in_bytes)\n        return NotImplemented\n\n    def __mul__(self, other: int | float):\n        return Memory.from_bytes(round(self.in_bytes * other))\n\n    def __rmul__(self, other: int | float):\n        return self * other\n\n    @overload\n    def __truediv__(self, other: \"Memory\") -> float: ...\n    @overload\n    def __truediv__(self, other: int) -> \"Memory\": ...\n    @overload\n    def __truediv__(self, other: float) -> \"Memory\": ...\n    def __truediv__(self, other: object) -> \"Memory | float\":\n        if isinstance(other, Memory):\n            return self.in_bytes / 
other.in_bytes\n        if isinstance(other, (int, float)):\n            return Memory.from_bytes(round(self.in_bytes / other))\n        return NotImplemented\n\n    def __floordiv__(self, other: object) -> \"Memory\":\n        if isinstance(other, (int, float)):\n            return Memory.from_bytes(int(self.in_bytes // other))\n        return NotImplemented\n\n    def __lt__(self, other: object) -> bool:\n        if isinstance(other, Memory):\n            return self.in_bytes < other.in_bytes\n        return NotImplemented\n\n    def __le__(self, other: object) -> bool:\n        if isinstance(other, Memory):\n            return self.in_bytes <= other.in_bytes\n        return NotImplemented\n\n    def __gt__(self, other: object) -> bool:\n        if isinstance(other, Memory):\n            return self.in_bytes > other.in_bytes\n        return NotImplemented\n\n    def __ge__(self, other: object) -> bool:\n        if isinstance(other, Memory):\n            return self.in_bytes >= other.in_bytes\n        return NotImplemented\n\n    def __eq__(self, other: object) -> bool:\n        if isinstance(other, Memory):\n            return self.in_bytes == other.in_bytes\n        return NotImplemented\n\n    def __repr__(self) -> str:\n        return f\"Memory.from_bytes({self.in_bytes})\"\n\n    def __str__(self) -> str:\n        if self.in_gb > 2:\n            val = self.in_gb\n            unit = \"GiB\"\n        elif self.in_mb > 2:\n            val = self.in_mb\n            unit = \"MiB\"\n        elif self.in_kb > 3:\n            val = self.in_kb\n            unit = \"KiB\"\n        else:\n            val = self.in_bytes\n            unit = \"B\"\n\n        # format the number, trim trailing zeros, then append the unit once\n        return f\"{val:.2f}\".rstrip(\"0\").rstrip(\".\") + f\" {unit}\"\n"
  },
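A quick tour of `Memory` arithmetic: operators preserve the type, dividing one `Memory` by another yields a plain ratio, and `str()` picks a human-readable unit.

```python
from exo.shared.types.memory import Memory

total = Memory.from_gb(8) + Memory.from_mb(512)
assert total.in_mb == 8 * 1024 + 512

half = total / 2                   # Memory / number -> Memory
ratio = Memory.from_gb(4) / total  # Memory / Memory -> float
assert isinstance(half, Memory) and 0 < ratio < 1

print(total)  # "8.5 GiB"
```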
  {
    "path": "src/exo/shared/types/mlx.py",
    "content": "\"\"\"Shared types for MLX-related functionality.\"\"\"\n\nfrom collections.abc import Sequence\n\nfrom mlx import core as mx\nfrom mlx import nn as nn\nfrom mlx_lm.models.cache import (\n    ArraysCache,\n    CacheList,\n    KVCache,\n    QuantizedKVCache,\n    RotatingKVCache,\n)\n\n# This list contains one cache entry per transformer layer\nKVCacheType = Sequence[\n    KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache | CacheList\n]\n\n\n# Model is a wrapper function to fix the fact that mlx is not strongly typed in the same way that EXO is.\n# For example - MLX has no guarantee of the interface that nn.Module will expose. But we need a guarantee that it has a __call__() function\nclass Model(nn.Module):\n    layers: list[nn.Module]\n\n    def __call__(\n        self,\n        x: mx.array,\n        cache: KVCacheType | None,\n        input_embeddings: mx.array | None = None,\n    ) -> mx.array: ...\n"
  },
  {
    "path": "src/exo/shared/types/multiaddr.py",
    "content": "import re\nfrom typing import ClassVar\n\nfrom pydantic import BaseModel, ConfigDict, computed_field, field_validator\n\n\nclass Multiaddr(BaseModel):\n    model_config = ConfigDict(frozen=True)\n    address: str\n\n    PATTERNS: ClassVar[list[str]] = [\n        r\"^/ip4/(\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3})(/tcp/(\\d{1,5}))?(/p2p/[A-Za-z0-9]+)?$\",\n        r\"^/ip6/([0-9a-fA-F:]+)(/tcp/(\\d{1,5}))?(/p2p/[A-Za-z0-9]+)?$\",\n        r\"^/dns[46]?/([a-zA-Z0-9.-]+)(/tcp/(\\d{1,5}))?(/p2p/[A-Za-z0-9]+)?$\",\n    ]\n\n    @field_validator(\"address\")\n    @classmethod\n    def validate_format(cls, v: str) -> str:\n        if not any(re.match(pattern, v) for pattern in cls.PATTERNS):\n            raise ValueError(\n                f\"Invalid multiaddr format: {v}. \"\n                \"Expected format like /ip4/127.0.0.1/tcp/4001 or /dns/example.com/tcp/443\"\n            )\n        return v\n\n    @computed_field\n    @property\n    def address_type(self) -> str:\n        for pattern in self.PATTERNS:\n            if re.match(pattern, self.address):\n                return pattern.split(\"/\")[1]\n        raise ValueError(f\"Invalid multiaddr format: {self.address}\")\n\n    @property\n    def ipv6_address(self) -> str:\n        match = re.match(r\"^/ip6/([0-9a-fA-F:]+)\", self.address)\n        if not match:\n            raise ValueError(\n                f\"Invalid multiaddr format: {self.address}. Expected format like /ip6/::1/tcp/4001\"\n            )\n        return match.group(1)\n\n    @property\n    def ipv4_address(self) -> str:\n        match = re.match(r\"^/ip4/(\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3})\", self.address)\n        if not match:\n            raise ValueError(\n                f\"Invalid multiaddr format: {self.address}. Expected format like /ip4/127.0.0.1/tcp/4001\"\n            )\n        return match.group(1)\n\n    @computed_field\n    @property\n    def ip_address(self) -> str:\n        return self.ipv4_address if self.address_type == \"ip4\" else self.ipv6_address\n\n    @computed_field\n    @property\n    def port(self) -> int:\n        match = re.search(r\"/tcp/(\\d{1,5})\", self.address)\n        if not match:\n            raise ValueError(\n                f\"Invalid multiaddr format: {self.address}. Expected format like /ip4/127.0.0.1/tcp/4001\"\n            )\n        return int(match.group(1))\n\n    def __str__(self) -> str:\n        return self.address\n"
  },
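A parsing sketch for `Multiaddr`: the computed fields pull components out of the validated string, and malformed addresses are rejected at construction time (the validator's `ValueError` surfaces as a Pydantic `ValidationError`).

```python
from pydantic import ValidationError

from exo.shared.types.multiaddr import Multiaddr

addr = Multiaddr(address="/ip4/127.0.0.1/tcp/4001")
assert (addr.address_type, addr.ip_address, addr.port) == ("ip4", "127.0.0.1", 4001)

try:
    Multiaddr(address="not-a-multiaddr")
except ValidationError:
    pass  # rejected by validate_format
```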
  {
    "path": "src/exo/shared/types/profiling.py",
    "content": "import shutil\nfrom collections.abc import Sequence\nfrom pathlib import Path\nfrom typing import Literal, Self\n\nimport psutil\n\nfrom exo.shared.types.memory import Memory\nfrom exo.shared.types.thunderbolt import ThunderboltIdentifier\nfrom exo.utils.pydantic_ext import CamelCaseModel\n\n\nclass MemoryUsage(CamelCaseModel):\n    ram_total: Memory\n    ram_available: Memory\n    swap_total: Memory\n    swap_available: Memory\n\n    @classmethod\n    def from_bytes(\n        cls, *, ram_total: int, ram_available: int, swap_total: int, swap_available: int\n    ) -> Self:\n        return cls(\n            ram_total=Memory.from_bytes(ram_total),\n            ram_available=Memory.from_bytes(ram_available),\n            swap_total=Memory.from_bytes(swap_total),\n            swap_available=Memory.from_bytes(swap_available),\n        )\n\n    @classmethod\n    def from_psutil(cls, *, override_memory: int | None) -> Self:\n        vm = psutil.virtual_memory()\n        sm = psutil.swap_memory()\n\n        return cls.from_bytes(\n            ram_total=vm.total,\n            ram_available=vm.available if override_memory is None else override_memory,\n            swap_total=sm.total,\n            swap_available=sm.free,\n        )\n\n\nclass DiskUsage(CamelCaseModel):\n    \"\"\"Disk space usage for the models directory.\"\"\"\n\n    total: Memory\n    available: Memory\n\n    @classmethod\n    def from_path(cls, path: Path) -> Self:\n        \"\"\"Get disk usage stats for the partition containing path.\"\"\"\n        total, _used, free = shutil.disk_usage(path)\n        return cls(\n            total=Memory.from_bytes(total),\n            available=Memory.from_bytes(free),\n        )\n\n\nclass SystemPerformanceProfile(CamelCaseModel):\n    # TODO: flops_fp16: float\n\n    gpu_usage: float = 0.0\n    temp: float = 0.0\n    sys_power: float = 0.0\n    pcpu_usage: float = 0.0\n    ecpu_usage: float = 0.0\n\n\nInterfaceType = Literal[\"wifi\", \"ethernet\", \"maybe_ethernet\", \"thunderbolt\", \"unknown\"]\n\n\nclass NetworkInterfaceInfo(CamelCaseModel):\n    name: str\n    ip_address: str\n    interface_type: InterfaceType = \"unknown\"\n\n\nclass NodeIdentity(CamelCaseModel):\n    \"\"\"Static and slow-changing node identification data.\"\"\"\n\n    model_id: str = \"Unknown\"\n    chip_id: str = \"Unknown\"\n    friendly_name: str = \"Unknown\"\n    os_version: str = \"Unknown\"\n    os_build_version: str = \"Unknown\"\n\n\nclass NodeNetworkInfo(CamelCaseModel):\n    \"\"\"Network interface information for a node.\"\"\"\n\n    interfaces: Sequence[NetworkInterfaceInfo] = []\n\n\nclass NodeThunderboltInfo(CamelCaseModel):\n    \"\"\"Thunderbolt interface identifiers for a node.\"\"\"\n\n    interfaces: Sequence[ThunderboltIdentifier] = []\n\n\nclass NodeRdmaCtlStatus(CamelCaseModel):\n    \"\"\"Whether RDMA is enabled on this node (via rdma_ctl).\"\"\"\n\n    enabled: bool\n\n\nclass ThunderboltBridgeStatus(CamelCaseModel):\n    \"\"\"Whether the Thunderbolt Bridge network service is enabled on this node.\"\"\"\n\n    enabled: bool\n    exists: bool\n    service_name: str | None = None\n"
  },
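A usage sketch for the profiling snapshot helpers; the values naturally depend on the host they run on.

```python
from pathlib import Path

from exo.shared.types.profiling import DiskUsage, MemoryUsage

disk = DiskUsage.from_path(Path.home())                # partition holding $HOME
memory = MemoryUsage.from_psutil(override_memory=None)
print(f"disk free: {disk.available}, RAM free: {memory.ram_available}")
```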
  {
    "path": "src/exo/shared/types/state.py",
    "content": "from collections.abc import Mapping, Sequence\nfrom datetime import datetime\nfrom typing import Any, cast\n\nfrom pydantic import ConfigDict, Field, field_serializer, field_validator\nfrom pydantic.alias_generators import to_camel\n\nfrom exo.shared.topology import Topology, TopologySnapshot\nfrom exo.shared.types.common import NodeId\nfrom exo.shared.types.profiling import (\n    DiskUsage,\n    MemoryUsage,\n    NodeIdentity,\n    NodeNetworkInfo,\n    NodeRdmaCtlStatus,\n    NodeThunderboltInfo,\n    SystemPerformanceProfile,\n    ThunderboltBridgeStatus,\n)\nfrom exo.shared.types.tasks import Task, TaskId\nfrom exo.shared.types.worker.downloads import DownloadProgress\nfrom exo.shared.types.worker.instances import Instance, InstanceId\nfrom exo.shared.types.worker.runners import RunnerId, RunnerStatus\nfrom exo.utils.pydantic_ext import CamelCaseModel\n\n\nclass State(CamelCaseModel):\n    \"\"\"Global system state.\n\n    The :class:`Topology` instance is encoded/decoded via an immutable\n    :class:`~shared.topology.TopologySnapshot` to ensure compatibility with\n    standard JSON serialisation.\n    \"\"\"\n\n    model_config = ConfigDict(\n        alias_generator=to_camel,\n        validate_by_name=True,\n        extra=\"forbid\",\n        # I want to reenable this ASAP, but it's causing an issue with TaskStatus\n        strict=True,\n        arbitrary_types_allowed=True,\n    )\n    instances: Mapping[InstanceId, Instance] = {}\n    runners: Mapping[RunnerId, RunnerStatus] = {}\n    downloads: Mapping[NodeId, Sequence[DownloadProgress]] = {}\n    tasks: Mapping[TaskId, Task] = {}\n    last_seen: Mapping[NodeId, datetime] = {}\n    topology: Topology = Field(default_factory=Topology)\n    last_event_applied_idx: int = Field(default=-1, ge=-1)\n\n    # Granular node state mappings (update independently at different frequencies)\n    node_identities: Mapping[NodeId, NodeIdentity] = {}\n    node_memory: Mapping[NodeId, MemoryUsage] = {}\n    node_disk: Mapping[NodeId, DiskUsage] = {}\n    node_system: Mapping[NodeId, SystemPerformanceProfile] = {}\n    node_network: Mapping[NodeId, NodeNetworkInfo] = {}\n    node_thunderbolt: Mapping[NodeId, NodeThunderboltInfo] = {}\n    node_thunderbolt_bridge: Mapping[NodeId, ThunderboltBridgeStatus] = {}\n    node_rdma_ctl: Mapping[NodeId, NodeRdmaCtlStatus] = {}\n\n    # Detected cycles where all nodes have Thunderbolt bridge enabled (>2 nodes)\n    thunderbolt_bridge_cycles: Sequence[Sequence[NodeId]] = []\n\n    @field_serializer(\"topology\", mode=\"plain\")\n    def _encode_topology(self, value: Topology) -> TopologySnapshot:\n        return value.to_snapshot()\n\n    @field_validator(\"topology\", mode=\"before\")\n    @classmethod\n    def _deserialize_topology(cls, value: object) -> Topology:  # noqa: D401 – Pydantic validator signature\n        \"\"\"Convert an incoming *value* into a :class:`Topology` instance.\n\n        Accepts either an already constructed :class:`Topology` or a mapping\n        representing :class:`~shared.topology.TopologySnapshot`.\n        \"\"\"\n\n        if isinstance(value, Topology):\n            return value\n\n        if isinstance(value, Mapping):  # likely a snapshot-dict coming from JSON\n            snapshot = TopologySnapshot(**cast(dict[str, Any], value))  # type: ignore[arg-type]\n            return Topology.from_snapshot(snapshot)\n\n        raise TypeError(\"Invalid representation for Topology field in State\")\n"
  },
  {
    "path": "src/exo/shared/types/tasks.py",
    "content": "from enum import Enum\n\nfrom pydantic import Field\n\nfrom exo.api.types import (\n    ImageEditsTaskParams,\n    ImageGenerationTaskParams,\n)\nfrom exo.shared.types.common import CommandId, Id\nfrom exo.shared.types.text_generation import TextGenerationTaskParams\nfrom exo.shared.types.worker.instances import BoundInstance, InstanceId\nfrom exo.shared.types.worker.runners import RunnerId\nfrom exo.shared.types.worker.shards import ShardMetadata\nfrom exo.utils.pydantic_ext import TaggedModel\n\n\nclass TaskId(Id):\n    pass\n\n\nCANCEL_ALL_TASKS = TaskId(\"CANCEL_ALL_TASKS\")\n\n\nclass TaskStatus(str, Enum):\n    Pending = \"Pending\"\n    Running = \"Running\"\n    Complete = \"Complete\"\n    TimedOut = \"TimedOut\"\n    Failed = \"Failed\"\n    Cancelled = \"Cancelled\"\n\n\nclass BaseTask(TaggedModel):\n    task_id: TaskId = Field(default_factory=TaskId)\n    task_status: TaskStatus = Field(default=TaskStatus.Pending)\n    instance_id: InstanceId\n\n\nclass CreateRunner(BaseTask):  # emitted by Worker\n    bound_instance: BoundInstance\n\n\nclass DownloadModel(BaseTask):  # emitted by Worker\n    shard_metadata: ShardMetadata\n\n\nclass LoadModel(BaseTask):  # emitted by Worker\n    pass\n\n\nclass ConnectToGroup(BaseTask):  # emitted by Worker\n    pass\n\n\nclass StartWarmup(BaseTask):  # emitted by Worker\n    pass\n\n\nclass TextGeneration(BaseTask):  # emitted by Master\n    command_id: CommandId\n    task_params: TextGenerationTaskParams\n\n    error_type: str | None = Field(default=None)\n    error_message: str | None = Field(default=None)\n\n\nclass CancelTask(BaseTask):\n    cancelled_task_id: TaskId\n    runner_id: RunnerId\n\n\nclass ImageGeneration(BaseTask):  # emitted by Master\n    command_id: CommandId\n    task_params: ImageGenerationTaskParams\n\n    error_type: str | None = Field(default=None)\n    error_message: str | None = Field(default=None)\n\n\nclass ImageEdits(BaseTask):  # emitted by Master\n    command_id: CommandId\n    task_params: ImageEditsTaskParams\n\n    error_type: str | None = Field(default=None)\n    error_message: str | None = Field(default=None)\n\n\nclass Shutdown(BaseTask):  # emitted by Worker\n    runner_id: RunnerId\n\n\nTask = (\n    CreateRunner\n    | DownloadModel\n    | ConnectToGroup\n    | LoadModel\n    | StartWarmup\n    | TextGeneration\n    | CancelTask\n    | ImageGeneration\n    | ImageEdits\n    | Shutdown\n)\n"
  },
  {
    "path": "src/exo/shared/types/text_generation.py",
    "content": "\"\"\"Canonical internal type for text generation task parameters.\n\nAll external API formats (Chat Completions, Claude Messages, OpenAI Responses)\nare converted to TextGenerationTaskParams at the API boundary via adapters.\n\"\"\"\n\nfrom typing import Any, Literal\n\nfrom pydantic import BaseModel\n\nfrom exo.shared.types.common import ModelId\n\nMessageRole = Literal[\"user\", \"assistant\", \"system\", \"developer\"]\nReasoningEffort = Literal[\"none\", \"minimal\", \"low\", \"medium\", \"high\", \"xhigh\"]\n\n\ndef resolve_reasoning_params(\n    reasoning_effort: ReasoningEffort | None,\n    enable_thinking: bool | None,\n) -> tuple[ReasoningEffort | None, bool | None]:\n    \"\"\"\n    enable_thinking=True  -> reasoning_effort=\"medium\"\n    enable_thinking=False -> reasoning_effort=\"none\"\n    reasoning_effort=\"none\" -> enable_thinking=False\n    reasoning_effort=<anything else> -> enable_thinking=True\n    \"\"\"\n    resolved_effort: ReasoningEffort | None = reasoning_effort\n    resolved_thinking: bool | None = enable_thinking\n\n    if reasoning_effort is None and enable_thinking is not None:\n        resolved_effort = \"medium\" if enable_thinking else \"none\"\n\n    if enable_thinking is None and reasoning_effort is not None:\n        resolved_thinking = reasoning_effort != \"none\"\n\n    return resolved_effort, resolved_thinking\n\n\nclass InputMessage(BaseModel, frozen=True):\n    \"\"\"Internal message for text generation pipelines.\"\"\"\n\n    role: MessageRole\n    content: str\n\n\nclass TextGenerationTaskParams(BaseModel, frozen=True):\n    \"\"\"Canonical internal task params for text generation.\n\n    Every API adapter converts its wire type into this before handing\n    off to the master/worker pipeline.\n    \"\"\"\n\n    model: ModelId\n    input: list[InputMessage]\n    instructions: str | None = None\n    max_output_tokens: int | None = None\n    temperature: float | None = None\n    top_p: float | None = None\n    stream: bool = False\n    tools: list[dict[str, Any]] | None = None\n    bench: bool = False\n    top_k: int | None = None\n    stop: str | list[str] | None = None\n    seed: int | None = None\n    chat_template_messages: list[dict[str, Any]] | None = None\n    reasoning_effort: ReasoningEffort | None = None\n    enable_thinking: bool | None = None\n    logprobs: bool = False\n    top_logprobs: int | None = None\n    min_p: float | None = None\n    repetition_penalty: float | None = None\n    repetition_context_size: int | None = None\n"
  },
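The resolution rules in `resolve_reasoning_params` are worth sanity-checking by hand. A minimal sketch of the documented cases (a standalone illustration, not part of the module):

```python
from exo.shared.types.text_generation import resolve_reasoning_params

# Each parameter only fills in the other when the other is missing:
assert resolve_reasoning_params(None, True) == ("medium", True)
assert resolve_reasoning_params(None, False) == ("none", False)
assert resolve_reasoning_params("none", None) == ("none", False)
assert resolve_reasoning_params("high", None) == ("high", True)
# When both are provided, both pass through unchanged:
assert resolve_reasoning_params("low", True) == ("low", True)
```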
  {
    "path": "src/exo/shared/types/thunderbolt.py",
    "content": "import anyio\nfrom pydantic import BaseModel, Field\n\nfrom exo.utils.pydantic_ext import CamelCaseModel\n\n\nclass ThunderboltConnection(CamelCaseModel):\n    source_uuid: str\n    sink_uuid: str\n\n\nclass ThunderboltIdentifier(CamelCaseModel):\n    rdma_interface: str\n    domain_uuid: str\n    link_speed: str = \"\"\n\n\n## Intentionally minimal, only collecting data we care about - there's a lot more\n\n\nclass _ReceptacleTag(BaseModel, extra=\"ignore\"):\n    receptacle_id_key: str | None = None\n    current_speed_key: str | None = None\n\n\nclass _ConnectivityItem(BaseModel, extra=\"ignore\"):\n    domain_uuid_key: str | None = None\n\n\nclass ThunderboltConnectivityData(BaseModel, extra=\"ignore\"):\n    domain_uuid_key: str | None = None\n    items: list[_ConnectivityItem] | None = Field(None, alias=\"_items\")\n    receptacle_1_tag: _ReceptacleTag | None = None\n\n    def ident(self, ifaces: dict[str, str]) -> ThunderboltIdentifier | None:\n        if (\n            self.domain_uuid_key is None\n            or self.receptacle_1_tag is None\n            or self.receptacle_1_tag.receptacle_id_key is None\n        ):\n            return\n        tag = f\"Thunderbolt {self.receptacle_1_tag.receptacle_id_key}\"\n        assert tag in ifaces  # doesn't need to be an assertion but im confident\n        # if tag not in ifaces: return None\n        iface = f\"rdma_{ifaces[tag]}\"\n        return ThunderboltIdentifier(\n            rdma_interface=iface,\n            domain_uuid=self.domain_uuid_key,\n            link_speed=self.receptacle_1_tag.current_speed_key or \"\",\n        )\n\n    def conn(self) -> ThunderboltConnection | None:\n        if self.domain_uuid_key is None or self.items is None:\n            return\n\n        sink_key = next(\n            (\n                item.domain_uuid_key\n                for item in self.items\n                if item.domain_uuid_key is not None\n            ),\n            None,\n        )\n        if sink_key is None:\n            return None\n\n        return ThunderboltConnection(\n            source_uuid=self.domain_uuid_key, sink_uuid=sink_key\n        )\n\n\nclass ThunderboltConnectivity(BaseModel, extra=\"ignore\"):\n    SPThunderboltDataType: list[ThunderboltConnectivityData] = []\n\n    @classmethod\n    async def gather(cls) -> list[ThunderboltConnectivityData] | None:\n        proc = await anyio.run_process(\n            [\"system_profiler\", \"SPThunderboltDataType\", \"-json\"], check=False\n        )\n        if proc.returncode != 0:\n            return None\n        # Saving you from PascalCase while avoiding too much pydantic\n        return ThunderboltConnectivity.model_validate_json(\n            proc.stdout\n        ).SPThunderboltDataType\n"
  },
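For orientation, here is a small sketch of how one `system_profiler` entry flows through these models. The dictionary values are made up; only the key names (`domain_uuid_key`, `_items`) mirror the fields and alias declared above:

```python
from exo.shared.types.thunderbolt import ThunderboltConnectivityData

entry = ThunderboltConnectivityData.model_validate(
    {
        "domain_uuid_key": "source-uuid",
        "_items": [{"domain_uuid_key": "sink-uuid"}],
    }
)

# conn() pairs this domain's UUID with the first connected domain it finds.
connection = entry.conn()
assert connection is not None
assert connection.source_uuid == "source-uuid"
assert connection.sink_uuid == "sink-uuid"
```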
  {
    "path": "src/exo/shared/types/topology.py",
    "content": "from collections.abc import Iterator\nfrom dataclasses import dataclass\n\nfrom exo.shared.types.common import NodeId\nfrom exo.shared.types.multiaddr import Multiaddr\nfrom exo.utils.pydantic_ext import FrozenModel\n\n\n@dataclass(frozen=True)\nclass Cycle:\n    node_ids: list[NodeId]\n\n    def __len__(self) -> int:\n        return self.node_ids.__len__()\n\n    def __iter__(self) -> Iterator[NodeId]:\n        return self.node_ids.__iter__()\n\n\nclass RDMAConnection(FrozenModel):\n    source_rdma_iface: str\n    sink_rdma_iface: str\n\n\nclass SocketConnection(FrozenModel):\n    sink_multiaddr: Multiaddr\n\n    def __hash__(self):\n        return hash(self.sink_multiaddr.ip_address)\n\n\nclass Connection(FrozenModel):\n    source: NodeId\n    sink: NodeId\n    edge: RDMAConnection | SocketConnection\n"
  },
  {
    "path": "src/exo/shared/types/worker/downloads.py",
    "content": "from datetime import timedelta\nfrom typing import Literal\n\nfrom pydantic import BaseModel, ConfigDict, Field, PositiveInt\n\nfrom exo.shared.types.common import NodeId\nfrom exo.shared.types.memory import Memory\nfrom exo.shared.types.worker.shards import ShardMetadata\nfrom exo.utils.pydantic_ext import CamelCaseModel, TaggedModel\n\n\nclass DownloadProgressData(CamelCaseModel):\n    total: Memory\n    downloaded: Memory\n    downloaded_this_session: Memory\n\n    completed_files: int\n    total_files: int\n\n    speed: float\n    eta_ms: int\n\n    files: dict[str, \"DownloadProgressData\"]\n\n\nclass BaseDownloadProgress(TaggedModel):\n    node_id: NodeId\n    shard_metadata: ShardMetadata\n    model_directory: str = \"\"\n\n\nclass DownloadPending(BaseDownloadProgress):\n    downloaded: Memory = Memory()\n    total: Memory = Memory()\n\n\nclass DownloadCompleted(BaseDownloadProgress):\n    total: Memory\n    read_only: bool = False\n\n\nclass DownloadFailed(BaseDownloadProgress):\n    error_message: str\n\n\nclass DownloadOngoing(BaseDownloadProgress):\n    download_progress: DownloadProgressData\n\n\nDownloadProgress = (\n    DownloadPending | DownloadCompleted | DownloadFailed | DownloadOngoing\n)\n\n\nclass ModelSafetensorsIndexMetadata(BaseModel):\n    total_size: PositiveInt\n\n\nclass ModelSafetensorsIndex(BaseModel):\n    metadata: ModelSafetensorsIndexMetadata | None\n    weight_map: dict[str, str]\n\n\nclass FileListEntry(BaseModel):\n    type: Literal[\"file\", \"directory\"]\n    path: str\n    size: int | None = None\n\n\nclass RepoFileDownloadProgress(BaseModel):\n    repo_id: str\n    repo_revision: str\n    file_path: str\n    downloaded: Memory\n    downloaded_this_session: Memory\n    total: Memory\n    speed: float\n    eta: timedelta\n    status: Literal[\"not_started\", \"in_progress\", \"complete\"]\n    start_time: float\n\n    model_config = ConfigDict(frozen=True)\n\n\nclass RepoDownloadProgress(BaseModel):\n    repo_id: str\n    repo_revision: str\n    shard: ShardMetadata\n    completed_files: int\n    total_files: int\n    downloaded: Memory\n    downloaded_this_session: Memory\n    total: Memory\n    overall_speed: float\n    overall_eta: timedelta\n    status: Literal[\"not_started\", \"in_progress\", \"complete\"]\n    file_progress: dict[str, RepoFileDownloadProgress] = Field(default_factory=dict)\n\n    model_config = ConfigDict(frozen=True)\n"
  },
  {
    "path": "src/exo/shared/types/worker/instances.py",
    "content": "from enum import Enum\n\nfrom pydantic import model_validator\n\nfrom exo.shared.models.model_cards import ModelTask\nfrom exo.shared.types.common import Host, Id, NodeId\nfrom exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata\nfrom exo.utils.pydantic_ext import CamelCaseModel, TaggedModel\n\n\nclass InstanceId(Id):\n    pass\n\n\nclass InstanceMeta(str, Enum):\n    MlxRing = \"MlxRing\"\n    MlxJaccl = \"MlxJaccl\"\n\n\nclass BaseInstance(TaggedModel):\n    instance_id: InstanceId\n    shard_assignments: ShardAssignments\n\n    def shard(self, runner_id: RunnerId) -> ShardMetadata | None:\n        return self.shard_assignments.runner_to_shard.get(runner_id, None)\n\n\nclass MlxRingInstance(BaseInstance):\n    hosts_by_node: dict[NodeId, list[Host]]\n    ephemeral_port: int\n\n\nclass MlxJacclInstance(BaseInstance):\n    jaccl_devices: list[list[str | None]]\n    jaccl_coordinators: dict[NodeId, str]\n\n\n# TODO: Single node instance\nInstance = MlxRingInstance | MlxJacclInstance\n\n\nclass BoundInstance(CamelCaseModel):\n    instance: Instance\n    bound_runner_id: RunnerId\n    bound_node_id: NodeId\n\n    @property\n    def bound_shard(self) -> ShardMetadata:\n        shard = self.instance.shard(self.bound_runner_id)\n        assert shard is not None\n        return shard\n\n    @property\n    def is_image_model(self) -> bool:\n        return (\n            ModelTask.TextToImage in self.bound_shard.model_card.tasks\n            or ModelTask.ImageToImage in self.bound_shard.model_card.tasks\n        )\n\n    @model_validator(mode=\"after\")\n    def validate_shard_exists(self) -> \"BoundInstance\":\n        assert (\n            self.bound_runner_id in self.instance.shard_assignments.runner_to_shard\n        ), (\n            \"Bound Instance must be constructed with a runner_id that is in the instances assigned shards\"\n        )\n        return self\n"
  },
  {
    "path": "src/exo/shared/types/worker/runner_response.py",
    "content": "from collections.abc import Generator\nfrom typing import Any, Literal\n\nfrom exo.api.types import (\n    FinishReason,\n    GenerationStats,\n    ImageGenerationStats,\n    ToolCallItem,\n    TopLogprobItem,\n    Usage,\n)\nfrom exo.utils.pydantic_ext import TaggedModel\n\n\nclass BaseRunnerResponse(TaggedModel):\n    pass\n\n\nclass TokenizedResponse(BaseRunnerResponse):\n    prompt_tokens: int\n\n\nclass GenerationResponse(BaseRunnerResponse):\n    text: str\n    token: int\n    logprob: float | None = None\n    top_logprobs: list[TopLogprobItem] | None = None\n    finish_reason: FinishReason | None = None\n    stats: GenerationStats | None = None\n    usage: Usage | None\n    is_thinking: bool = False\n\n\nclass ImageGenerationResponse(BaseRunnerResponse):\n    image_data: bytes\n    format: Literal[\"png\", \"jpeg\", \"webp\"] = \"png\"\n    stats: ImageGenerationStats | None = None\n    image_index: int = 0\n\n    def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:\n        for name, value in super().__repr_args__():  # pyright: ignore[reportAny]\n            if name == \"image_data\":\n                yield name, f\"<{len(self.image_data)} bytes>\"\n            elif name is not None:\n                yield name, value\n\n\nclass PartialImageResponse(BaseRunnerResponse):\n    image_data: bytes\n    format: Literal[\"png\", \"jpeg\", \"webp\"] = \"png\"\n    partial_index: int\n    total_partials: int\n    image_index: int = 0\n\n    def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:\n        for name, value in super().__repr_args__():  # pyright: ignore[reportAny]\n            if name == \"image_data\":\n                yield name, f\"<{len(self.image_data)} bytes>\"\n            elif name is not None:\n                yield name, value\n\n\nclass ToolCallResponse(BaseRunnerResponse):\n    tool_calls: list[ToolCallItem]\n    usage: Usage | None\n    stats: GenerationStats | None = None\n\n\nclass FinishedResponse(BaseRunnerResponse):\n    pass\n\n\nclass PrefillProgressResponse(BaseRunnerResponse):\n    processed_tokens: int\n    total_tokens: int\n"
  },
  {
    "path": "src/exo/shared/types/worker/runners.py",
    "content": "from collections.abc import Mapping\n\nfrom pydantic import model_validator\n\nfrom exo.shared.models.model_cards import ModelId\nfrom exo.shared.types.common import Id, NodeId\nfrom exo.shared.types.worker.shards import ShardMetadata\nfrom exo.utils.pydantic_ext import CamelCaseModel, TaggedModel\n\n\nclass RunnerId(Id):\n    pass\n\n\nclass RunnerError(Exception):\n    pass\n\n\nclass BaseRunnerStatus(TaggedModel):\n    def is_running(self):\n        return isinstance(self, RunnerRunning)\n\n\nclass RunnerIdle(BaseRunnerStatus):\n    pass\n\n\nclass RunnerConnecting(BaseRunnerStatus):\n    pass\n\n\nclass RunnerConnected(BaseRunnerStatus):\n    pass\n\n\nclass RunnerLoading(BaseRunnerStatus):\n    layers_loaded: int = 0\n    total_layers: int = 0\n\n\nclass RunnerLoaded(BaseRunnerStatus):\n    pass\n\n\nclass RunnerWarmingUp(BaseRunnerStatus):\n    pass\n\n\nclass RunnerReady(BaseRunnerStatus):\n    pass\n\n\nclass RunnerRunning(BaseRunnerStatus):\n    pass\n\n\nclass RunnerShuttingDown(BaseRunnerStatus):\n    pass\n\n\nclass RunnerShutdown(BaseRunnerStatus):\n    pass\n\n\nclass RunnerFailed(BaseRunnerStatus):\n    error_message: str | None = None\n\n\nRunnerStatus = (\n    RunnerIdle\n    | RunnerConnecting\n    | RunnerConnected\n    | RunnerLoading\n    | RunnerLoaded\n    | RunnerWarmingUp\n    | RunnerReady\n    | RunnerRunning\n    | RunnerShuttingDown\n    | RunnerShutdown\n    | RunnerFailed\n)\n\n\nclass ShardAssignments(CamelCaseModel):\n    model_id: ModelId\n    runner_to_shard: Mapping[RunnerId, ShardMetadata]\n    node_to_runner: Mapping[NodeId, RunnerId]\n\n    @model_validator(mode=\"after\")\n    def validate_runners_exist(self) -> \"ShardAssignments\":\n        for runner_id in self.node_to_runner.values():\n            if runner_id not in self.runner_to_shard:\n                raise ValueError(\n                    f\"Runner {runner_id} in node_to_runner does not exist in runner_to_shard\"\n                )\n        return self\n"
  },
  {
    "path": "src/exo/shared/types/worker/shards.py",
    "content": "from enum import Enum\nfrom typing import TypeAlias, final\n\nfrom pydantic import Field\n\nfrom exo.shared.models.model_cards import ModelCard\nfrom exo.utils.pydantic_ext import TaggedModel\n\n\nclass Sharding(str, Enum):\n    Tensor = \"Tensor\"\n    Pipeline = \"Pipeline\"\n\n\nclass BaseShardMetadata(TaggedModel):\n    \"\"\"\n    Defines a specific shard of the model that is ready to be run on a device.\n    Replaces previous `Shard` object.\n    \"\"\"\n\n    model_card: ModelCard\n    device_rank: int\n    world_size: int\n\n    # Error handling; equivalent to monkey-patch, but we can't monkey-patch runner.py\n    # This is kinda annoying because it allocates memory in the ShardMetadata object. Can be rethought after Shanghai.\n    immediate_exception: bool = False\n    should_timeout: float | None = None\n\n    start_layer: int = Field(ge=0)\n    end_layer: int = Field(ge=0)\n    n_layers: int = Field(ge=0)\n\n    @property\n    def is_first_layer(self) -> bool:\n        return self.start_layer == 0\n\n    @property\n    def is_last_layer(self) -> bool:\n        return self.end_layer == self.n_layers\n\n    def __hash__(self) -> int:\n        return hash(\n            (\n                self.model_card.model_id,\n                self.start_layer,\n                self.end_layer,\n                self.n_layers,\n                self.device_rank,\n                self.world_size,\n            )\n        )\n\n\n@final\nclass PipelineShardMetadata(BaseShardMetadata):\n    \"\"\"\n    Pipeline parallelism shard meta.\n\n    Layers are represented as a half-open interval [start_layer, end_layer),\n    where start_layer is inclusive and end_layer is exclusive.\n    \"\"\"\n\n\n@final\nclass CfgShardMetadata(BaseShardMetadata):\n    \"\"\"Shard metadata for CFG-parallel image generation models.\"\"\"\n\n    cfg_rank: int  # 0 = positive branch, 1 = negative branch\n    cfg_world_size: int = 2\n\n    # Pipeline-relative coordinates (computed at placement time)\n    pipeline_rank: int  # rank within the pipeline group (0, 1, 2, ...)\n    pipeline_world_size: int  # number of nodes per pipeline group\n\n\n@final\nclass TensorShardMetadata(BaseShardMetadata):\n    pass\n\n\nShardMetadata: TypeAlias = (\n    PipelineShardMetadata | CfgShardMetadata | TensorShardMetadata\n)\n"
  },
  {
    "path": "src/exo/utils/__init__.py",
    "content": "from typing import Any, Type\n\nfrom .phantom import PhantomData\n\n\ndef ensure_type[T](obj: Any, expected_type: Type[T]) -> T:  # type: ignore\n    if not isinstance(obj, expected_type):\n        raise TypeError(f\"Expected {expected_type}, got {type(obj)}\")  # type: ignore\n    return obj\n\n\ndef todo[T](\n    msg: str = \"This code has not been implemented yet.\",\n    _phantom: PhantomData[T] = None,\n) -> T:\n    raise NotImplementedError(msg)\n"
  },
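A quick sketch of `ensure_type` in use: it narrows a loosely-typed value to the expected type, or raises:

```python
from exo.utils import ensure_type

payload: object = "hello"
text = ensure_type(payload, str)  # returns the same value, typed as str
assert text.upper() == "HELLO"

try:
    ensure_type(payload, int)
except TypeError as error:
    print(error)  # e.g. "Expected <class 'int'>, got <class 'str'>"
```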
  {
    "path": "src/exo/utils/banner.py",
    "content": "import logging\nimport os\nimport sys\nimport webbrowser\n\nfrom exo.shared.constants import EXO_CONFIG_HOME\n\nlogger = logging.getLogger(__name__)\n\n_FIRST_RUN_MARKER = EXO_CONFIG_HOME / \".dashboard_opened\"\n\n\ndef _is_first_run() -> bool:\n    return not _FIRST_RUN_MARKER.exists()\n\n\ndef _mark_first_run_done() -> None:\n    _FIRST_RUN_MARKER.parent.mkdir(parents=True, exist_ok=True)\n    _FIRST_RUN_MARKER.touch()\n\n\ndef print_startup_banner(port: int) -> None:\n    dashboard_url = f\"http://localhost:{port}\"\n    first_run = _is_first_run()\n    banner = f\"\"\"\n╔═══════════════════════════════════════════════════════════════════════╗\n║                                                                       ║\n║   ███████╗██╗  ██╗ ██████╗                                            ║\n║   ██╔════╝╚██╗██╔╝██╔═══██╗                                           ║\n║   █████╗   ╚███╔╝ ██║   ██║                                           ║\n║   ██╔══╝   ██╔██╗ ██║   ██║                                           ║\n║   ███████╗██╔╝ ██╗╚██████╔╝                                           ║\n║   ╚══════╝╚═╝  ╚═╝ ╚═════╝                                            ║\n║                                                                       ║\n║   Distributed AI Inference Cluster                                    ║\n║                                                                       ║\n╚═══════════════════════════════════════════════════════════════════════╝\n\n╔═══════════════════════════════════════════════════════════════════════╗\n║                                                                       ║\n║  🌐 Dashboard & API Ready                                             ║\n║                                                                       ║\n║  {dashboard_url}{\" \" * (69 - len(dashboard_url))}║\n║                                                                       ║\n║  Click the URL above to open the dashboard in your browser            ║\n║                                                                       ║\n╚═══════════════════════════════════════════════════════════════════════╝\n\n\"\"\"\n\n    print(banner, file=sys.stderr)\n\n    if first_run:\n        # Skip browser open when running inside the native macOS app —\n        # FirstLaunchPopout.swift handles the auto-open with a countdown.\n        if not os.environ.get(\"EXO_RUNTIME_DIR\"):\n            try:\n                webbrowser.open(dashboard_url)\n                logger.info(\"First run detected — opening dashboard in browser\")\n            except Exception:\n                logger.debug(\"Could not auto-open browser\", exc_info=True)\n        _mark_first_run_done()\n"
  },
  {
    "path": "src/exo/utils/channels.py",
    "content": "import contextlib\nimport multiprocessing as mp\nfrom dataclasses import dataclass, field\nfrom math import inf\nfrom multiprocessing.synchronize import Event\nfrom queue import Empty, Full\nfrom types import TracebackType\nfrom typing import Any, Self\n\nfrom anyio import (\n    CapacityLimiter,\n    ClosedResourceError,\n    EndOfStream,\n    WouldBlock,\n    to_thread,\n)\nfrom anyio.streams.memory import (\n    MemoryObjectReceiveStream as AnyioReceiver,\n)\nfrom anyio.streams.memory import (\n    MemoryObjectSendStream as AnyioSender,\n)\nfrom anyio.streams.memory import (\n    MemoryObjectStreamState as AnyioState,\n)\n\n\nclass Sender[T](AnyioSender[T]):\n    def clone(self) -> \"Sender[T]\":\n        if self._closed:\n            raise ClosedResourceError\n        return Sender(_state=self._state)\n\n    def clone_receiver(self) -> \"Receiver[T]\":\n        \"\"\"Constructs a Receiver using a Senders shared state - similar to calling Receiver.clone() without needing the receiver\"\"\"\n        if self._closed:\n            raise ClosedResourceError\n        return Receiver(_state=self._state)\n\n\nclass Receiver[T](AnyioReceiver[T]):\n    def clone(self) -> \"Receiver[T]\":\n        if self._closed:\n            raise ClosedResourceError\n        return Receiver(_state=self._state)\n\n    def clone_sender(self) -> Sender[T]:\n        \"\"\"Constructs a Sender using a Receivers shared state - similar to calling Sender.clone() without needing the sender\"\"\"\n        if self._closed:\n            raise ClosedResourceError\n        return Sender(_state=self._state)\n\n    def collect(self) -> list[T]:\n        \"\"\"Collect all currently available items from this receiver\"\"\"\n        out: list[T] = []\n        while True:\n            try:\n                item = self.receive_nowait()\n                out.append(item)\n            except WouldBlock:\n                break\n        return out\n\n    async def receive_at_least(self, n: int) -> list[T]:\n        out: list[T] = []\n        out.append(await self.receive())\n        out.extend(self.collect())\n        while len(out) < n:\n            out.append(await self.receive())\n            out.extend(self.collect())\n        return out\n\n    def __enter__(self) -> Self:\n        return self\n\n\nclass _MpEndOfStream:\n    pass\n\n\nclass MpState[T]:\n    def __init__(self, max_buffer_size: float):\n        if max_buffer_size == inf:\n            max_buffer_size = 0\n        assert isinstance(max_buffer_size, int), (\n            \"State should only ever be constructed with an integer or math.inf size.\"\n        )\n\n        self.max_buffer_size: float = max_buffer_size\n        self.buffer: mp.Queue[T | _MpEndOfStream] = mp.Queue(max_buffer_size)\n        self.closed: Event = mp.Event()\n\n    def __getstate__(self):\n        d = self.__dict__.copy()\n        d.pop(\"__orig_class__\", None)\n        return d\n\n\n@dataclass(eq=False)\nclass MpSender[T]:\n    \"\"\"\n    An interprocess channel, mimicing the Anyio structure.\n    It should be noted that none of the clone methods are implemented for simplicity, for now.\n    \"\"\"\n\n    _state: MpState[T] = field()\n\n    def send_nowait(self, item: T) -> None:\n        if self._state.closed.is_set():\n            raise ClosedResourceError\n        try:\n            self._state.buffer.put(item, block=False)\n        except Full:\n            raise WouldBlock from None\n        except ValueError as e:\n            print(\"Unreachable code path - let me know!\")\n  
          raise ClosedResourceError from e\n\n    def send(self, item: T) -> None:\n        if self._state.closed.is_set():\n            raise ClosedResourceError\n        try:\n            self.send_nowait(item)\n        except WouldBlock:\n            # put anyway, blocking\n            self._state.buffer.put(item, block=True)\n\n    async def send_async(self, item: T) -> None:\n        await to_thread.run_sync(\n            self.send, item, limiter=CapacityLimiter(1), abandon_on_cancel=True\n        )\n\n    def close(self) -> None:\n        if not self._state.closed.is_set():\n            self._state.closed.set()\n        with contextlib.suppress(Exception):\n            self._state.buffer.put_nowait(_MpEndOfStream())\n        self._state.buffer.close()\n\n    # == unique to Mp channels ==\n    def join(self) -> None:\n        \"\"\"Ensure any queued messages are resolved before continuing\"\"\"\n        assert self._state.closed.is_set(), (\n            \"Mp channels must be closed before being joined\"\n        )\n        self._state.buffer.join_thread()\n\n    # == context manager support ==\n    def __enter__(self) -> Self:\n        return self\n\n    def __exit__(\n        self,\n        exc_type: type[BaseException] | None,\n        exc_val: BaseException | None,\n        exc_tb: TracebackType | None,\n    ) -> None:\n        self.close()\n\n    def __getstate__(self) -> dict[str, Any]:\n        d = self.__dict__.copy()\n        d.pop(\"__orig_class__\", None)\n        return d\n\n\n@dataclass(eq=False)\nclass MpReceiver[T]:\n    \"\"\"\n    An interprocess channel, mimicing the Anyio structure.\n    It should be noted that none of the clone methods are implemented for simplicity, for now.\n    \"\"\"\n\n    _state: MpState[T] = field()\n\n    def receive_nowait(self) -> T:\n        if self._state.closed.is_set():\n            raise ClosedResourceError\n\n        try:\n            item = self._state.buffer.get(block=False)\n            if isinstance(item, _MpEndOfStream):\n                self.close()\n                raise EndOfStream\n            return item\n        except Empty:\n            raise WouldBlock from None\n        except ValueError as e:\n            print(\"Unreachable code path - let me know!\")\n            raise ClosedResourceError from e\n\n    def receive(self) -> T:\n        try:\n            return self.receive_nowait()\n        except WouldBlock:\n            try:\n                item = self._state.buffer.get()\n            except (TypeError, OSError):\n                # Queue pipe can get closed while we are blocked on get().\n                # The underlying connection._handle becomes None, causing\n                # TypeError in read(handle, remaining).\n                raise ClosedResourceError from None\n            if isinstance(item, _MpEndOfStream):\n                self.close()\n                raise EndOfStream from None\n            return item\n\n    async def receive_async(self) -> T:\n        return await to_thread.run_sync(\n            self.receive, limiter=CapacityLimiter(1), abandon_on_cancel=True\n        )\n\n    def close(self) -> None:\n        if not self._state.closed.is_set():\n            self._state.closed.set()\n        with contextlib.suppress(Exception):\n            self._state.buffer.put_nowait(_MpEndOfStream())\n        self._state.buffer.close()\n\n    # == unique to Mp channels ==\n    def join(self) -> None:\n        \"\"\"Block until all enqueued messages are drained off our side of the buffer\"\"\"\n        assert 
self._state.closed.is_set(), (\n            \"Mp channels must be closed before being joined\"\n        )\n        self._state.buffer.join_thread()\n\n    # == iterator support ==\n    def __iter__(self) -> Self:\n        return self\n\n    def __next__(self) -> T:\n        try:\n            return self.receive()\n        except EndOfStream:\n            raise StopIteration from None\n\n    # == async iterator support ==\n    def __aiter__(self) -> Self:\n        return self\n\n    async def __anext__(self) -> T:\n        try:\n            return await self.receive_async()\n        except EndOfStream:\n            raise StopAsyncIteration from None\n\n    # == context manager support ==\n    def __enter__(self) -> Self:\n        return self\n\n    def __exit__(\n        self,\n        exc_type: type[BaseException] | None,\n        exc_val: BaseException | None,\n        exc_tb: TracebackType | None,\n    ) -> None:\n        self.close()\n\n    def collect(self) -> list[T]:\n        \"\"\"Collect all currently available items from this receiver\"\"\"\n        out: list[T] = []\n        while True:\n            try:\n                item = self.receive_nowait()\n                out.append(item)\n            except WouldBlock:\n                break\n        return out\n\n    def receive_at_least(self, n: int) -> list[T]:\n        out: list[T] = []\n        out.append(self.receive())\n        out.extend(self.collect())\n        while len(out) < n:\n            out.append(self.receive())\n            out.extend(self.collect())\n        return out\n\n    def __getstate__(self):\n        d = self.__dict__.copy()\n        d.pop(\"__orig_class__\", None)\n        return d\n\n\nclass channel[T]:  # noqa: N801\n    \"\"\"Create a pair of asynchronous channels for communicating within the same process\"\"\"\n\n    def __new__(cls, max_buffer_size: float = inf) -> tuple[Sender[T], Receiver[T]]:\n        if max_buffer_size != inf and not isinstance(max_buffer_size, int):\n            raise ValueError(\"max_buffer_size must be either an integer or math.inf\")\n        state = AnyioState[T](max_buffer_size)\n        return Sender(_state=state), Receiver(_state=state)\n\n\nclass mp_channel[T]:  # noqa: N801\n    \"\"\"Create a pair of synchronous channels for interprocess communication\"\"\"\n\n    # max buffer size uses math.inf to represent an unbounded queue, and 0 to represent a yet unimplemented \"unbuffered\" queue.\n    def __new__(cls, max_buffer_size: float = inf) -> tuple[MpSender[T], MpReceiver[T]]:\n        if (\n            max_buffer_size == 0\n            or max_buffer_size != inf\n            and not isinstance(max_buffer_size, int)\n        ):\n            raise ValueError(\n                \"max_buffer_size must be either an integer or math.inf. 0-sized buffers are not supported by multiprocessing\"\n            )\n        state = MpState[T](max_buffer_size)\n        return MpSender(_state=state), MpReceiver(_state=state)\n"
  },
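A minimal usage sketch of the in-process pair returned by `channel`. Closing the sender ends the receiver's async iteration; that iteration behaviour comes from anyio's memory object streams, which `Sender` and `Receiver` subclass:

```python
import anyio

from exo.utils.channels import channel


async def main() -> None:
    sender, receiver = channel[int]()  # unbounded buffer by default

    async with anyio.create_task_group() as task_group:

        async def produce() -> None:
            for item in range(3):
                await sender.send(item)
            sender.close()  # the receiver sees EndOfStream and iteration stops

        task_group.start_soon(produce)
        received = [item async for item in receiver]

    assert received == [0, 1, 2]


anyio.run(main)
```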
  {
    "path": "src/exo/utils/dashboard_path.py",
    "content": "import sys\nfrom pathlib import Path\nfrom typing import cast\n\n\ndef find_resources() -> Path:\n    resources = _find_resources_in_repo() or _find_resources_in_bundle()\n    if resources is None:\n        raise FileNotFoundError(\n            \"Unable to locate resources. Did you clone the repo properly?\"\n        )\n    return resources\n\n\ndef _find_resources_in_repo() -> Path | None:\n    current_module = Path(__file__).resolve()\n    for parent in current_module.parents:\n        build = parent / \"resources\"\n        if build.is_dir():\n            return build\n    return None\n\n\ndef _find_resources_in_bundle() -> Path | None:\n    frozen_root = cast(str | None, getattr(sys, \"_MEIPASS\", None))\n    if frozen_root is None:\n        return None\n    candidate = Path(frozen_root) / \"resources\"\n    if candidate.is_dir():\n        return candidate\n    return None\n\n\ndef find_dashboard() -> Path:\n    dashboard = _find_dashboard_in_repo() or _find_dashboard_in_bundle()\n    if not dashboard:\n        raise FileNotFoundError(\n            \"Unable to locate dashboard assets - you probably forgot to run `cd dashboard && npm install && npm run build && cd ..`\"\n        )\n    return dashboard\n\n\ndef _find_dashboard_in_repo() -> Path | None:\n    current_module = Path(__file__).resolve()\n    for parent in current_module.parents:\n        build = parent / \"dashboard\" / \"build\"\n        if build.is_dir() and (build / \"index.html\").exists():\n            return build\n    return None\n\n\ndef _find_dashboard_in_bundle() -> Path | None:\n    frozen_root = cast(str | None, getattr(sys, \"_MEIPASS\", None))\n    if frozen_root is None:\n        return None\n    candidate = Path(frozen_root) / \"dashboard\"\n    if candidate.is_dir():\n        return candidate\n    return None\n"
  },
  {
    "path": "src/exo/utils/disk_event_log.py",
    "content": "import contextlib\nimport json\nfrom collections import OrderedDict\nfrom collections.abc import Iterator\nfrom datetime import datetime, timezone\nfrom io import BufferedRandom, BufferedReader\nfrom pathlib import Path\n\nimport msgspec\nimport zstandard\nfrom loguru import logger\nfrom pydantic import TypeAdapter\n\nfrom exo.shared.types.events import Event\n\n_EVENT_ADAPTER: TypeAdapter[Event] = TypeAdapter(Event)\n\n_HEADER_SIZE = 4  # uint32 big-endian\n_OFFSET_CACHE_SIZE = 128\n_MAX_ARCHIVES = 5\n\n\ndef _serialize_event(event: Event) -> bytes:\n    return msgspec.msgpack.encode(event.model_dump(mode=\"json\"))\n\n\ndef _deserialize_event(raw: bytes) -> Event:\n    # Decode msgpack into a Python dict, then re-encode as JSON for Pydantic.\n    # Pydantic's validate_json() uses JSON-mode coercion (e.g. string -> enum)\n    # even under strict=True, whereas validate_python() does not. Going through\n    # JSON is the only way to get correct round-trip deserialization without\n    # disabling strict mode or adding casts everywhere.\n    as_json = json.dumps(msgspec.msgpack.decode(raw, type=dict))\n    return _EVENT_ADAPTER.validate_json(as_json)\n\n\ndef _unpack_header(header: bytes) -> int:\n    return int.from_bytes(header, byteorder=\"big\")\n\n\ndef _skip_record(f: BufferedReader) -> bool:\n    \"\"\"Skip one length-prefixed record. Returns False on EOF.\"\"\"\n    header = f.read(_HEADER_SIZE)\n    if len(header) < _HEADER_SIZE:\n        return False\n    f.seek(_unpack_header(header), 1)\n    return True\n\n\ndef _read_record(f: BufferedReader) -> Event | None:\n    \"\"\"Read one length-prefixed record. Returns None on EOF.\"\"\"\n    header = f.read(_HEADER_SIZE)\n    if len(header) < _HEADER_SIZE:\n        return None\n    length = _unpack_header(header)\n    payload = f.read(length)\n    if len(payload) < length:\n        return None\n    return _deserialize_event(payload)\n\n\nclass DiskEventLog:\n    \"\"\"Append-only event log backed by a file on disk.\n\n    On-disk format: sequence of length-prefixed msgpack records.\n    Each record is [4-byte big-endian uint32 length][msgpack payload].\n\n    Uses a bounded LRU cache of event index → byte offset for efficient\n    random access without storing an offset per event.\n    \"\"\"\n\n    def __init__(self, directory: Path) -> None:\n        self._directory = directory\n        self._directory.mkdir(parents=True, exist_ok=True)\n        self._active_path = directory / \"events.bin\"\n        self._offset_cache: OrderedDict[int, int] = OrderedDict()\n        self._count: int = 0\n\n        # Rotate stale active file from a previous session/crash\n        if self._active_path.exists():\n            self._rotate(self._active_path, self._directory)\n\n        self._file: BufferedRandom = open(self._active_path, \"w+b\")  # noqa: SIM115\n\n    def _cache_offset(self, idx: int, offset: int) -> None:\n        self._offset_cache[idx] = offset\n        self._offset_cache.move_to_end(idx)\n        if len(self._offset_cache) > _OFFSET_CACHE_SIZE:\n            self._offset_cache.popitem(last=False)\n\n    def _seek_to(self, f: BufferedReader, target_idx: int) -> None:\n        \"\"\"Seek f to the byte offset of event target_idx, using cache or scanning forward.\"\"\"\n        if target_idx in self._offset_cache:\n            self._offset_cache.move_to_end(target_idx)\n            f.seek(self._offset_cache[target_idx])\n            return\n\n        # Find the highest cached index before target_idx\n        scan_from_idx = 
0\n        scan_from_offset = 0\n        for cached_idx in self._offset_cache:\n            if cached_idx < target_idx:\n                scan_from_idx = cached_idx\n                scan_from_offset = self._offset_cache[cached_idx]\n\n        # Scan forward, skipping records\n        f.seek(scan_from_offset)\n        for _ in range(scan_from_idx, target_idx):\n            _skip_record(f)\n\n        self._cache_offset(target_idx, f.tell())\n\n    def append(self, event: Event) -> None:\n        packed = _serialize_event(event)\n        self._file.write(len(packed).to_bytes(_HEADER_SIZE, byteorder=\"big\"))\n        self._file.write(packed)\n        self._count += 1\n\n    def read_range(self, start: int, end: int) -> Iterator[Event]:\n        \"\"\"Yield events from index start (inclusive) to end (exclusive).\"\"\"\n        end = min(end, self._count)\n        if start < 0 or end < 0 or start >= end:\n            return\n\n        self._file.flush()\n        with open(self._active_path, \"rb\") as f:\n            self._seek_to(f, start)\n            for _ in range(end - start):\n                event = _read_record(f)\n                if event is None:\n                    break\n                yield event\n\n            # Cache where we ended up so the next sequential read is a hit\n            if end < self._count:\n                self._cache_offset(end, f.tell())\n\n    def read_all(self) -> Iterator[Event]:\n        \"\"\"Yield all events from the log one at a time.\"\"\"\n        if self._count == 0:\n            return\n        self._file.flush()\n        with open(self._active_path, \"rb\") as f:\n            for _ in range(self._count):\n                event = _read_record(f)\n                if event is None:\n                    break\n                yield event\n\n    def __len__(self) -> int:\n        return self._count\n\n    def close(self) -> None:\n        \"\"\"Close the file and rotate active file to compressed archive.\"\"\"\n        if self._file.closed:\n            return\n        self._file.close()\n        if self._active_path.exists() and self._count > 0:\n            self._rotate(self._active_path, self._directory)\n        elif self._active_path.exists():\n            self._active_path.unlink()\n\n    @staticmethod\n    def _rotate(source: Path, directory: Path) -> None:\n        \"\"\"Compress source into a timestamped archive.\n\n        Keeps at most ``_MAX_ARCHIVES`` compressed copies.  Oldest beyond\n        the limit are deleted.\n        \"\"\"\n        try:\n            stamp = datetime.now(timezone.utc).strftime(\"%Y-%m-%d_%H-%M-%S_%f\")\n            dest = directory / f\"events.{stamp}.bin.zst\"\n            compressor = zstandard.ZstdCompressor()\n            with open(source, \"rb\") as f_in, open(dest, \"wb\") as f_out:\n                compressor.copy_stream(f_in, f_out)\n            source.unlink()\n            logger.info(f\"Rotated event log: {source} -> {dest}\")\n\n            # Prune oldest archives beyond the limit\n            archives = sorted(directory.glob(\"events.*.bin.zst\"))\n            for old in archives[:-_MAX_ARCHIVES]:\n                old.unlink()\n        except Exception as e:\n            logger.opt(exception=e).warning(f\"Failed to rotate event log {source}\")\n            # Clean up the source even if compression fails\n            with contextlib.suppress(OSError):\n                source.unlink()\n"
  },
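The record framing is simple enough to poke at directly. A self-contained sketch of the `[4-byte big-endian uint32 length][msgpack payload]` layout described in the docstring, using a plain dict as a stand-in for a real `Event`:

```python
import msgspec

payload = msgspec.msgpack.encode({"kind": "example", "value": 42})

with open("/tmp/records.bin", "wb") as f:
    f.write(len(payload).to_bytes(4, byteorder="big"))  # header: payload length
    f.write(payload)                                    # body: msgpack bytes

with open("/tmp/records.bin", "rb") as f:
    length = int.from_bytes(f.read(4), byteorder="big")
    record = msgspec.msgpack.decode(f.read(length), type=dict)

assert record == {"kind": "example", "value": 42}
```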
  {
    "path": "src/exo/utils/event_buffer.py",
    "content": "from loguru import logger\n\n\nclass OrderedBuffer[T]:\n    \"\"\"\n    A buffer that resequences events to ensure their ordering is preserved.\n    Currently this buffer doesn't raise any errors if an event is lost\n    This buffer is NOT thread safe, and is designed to only be polled from one\n    source at a time.\n    \"\"\"\n\n    def __init__(self):\n        self.store: dict[int, T] = {}\n        self.next_idx_to_release: int = 0\n\n    def ingest(self, idx: int, t: T):\n        \"\"\"Ingest a sequence into the buffer\"\"\"\n        logger.trace(f\"Ingested event {t}\")\n        if idx < self.next_idx_to_release:\n            return\n        if idx in self.store:\n            assert self.store[idx] == t, (\n                \"Received different messages with identical indices, probable race condition\"\n            )\n            return\n        self.store[idx] = t\n\n    def drain(self) -> list[T]:\n        \"\"\"Drain all available events from the buffer\"\"\"\n        ret: list[T] = []\n        while self.next_idx_to_release in self.store:\n            idx = self.next_idx_to_release\n            event = self.store.pop(idx)\n            ret.append(event)\n            self.next_idx_to_release += 1\n        logger.trace(f\"Releasing event {ret}\")\n        return ret\n\n    def drain_indexed(self) -> list[tuple[int, T]]:\n        \"\"\"Drain all available events from the buffer\"\"\"\n        ret: list[tuple[int, T]] = []\n        while self.next_idx_to_release in self.store:\n            idx = self.next_idx_to_release\n            event = self.store.pop(idx)\n            ret.append((idx, event))\n            self.next_idx_to_release += 1\n        logger.trace(f\"Releasing event {ret}\")\n        return ret\n\n\nclass MultiSourceBuffer[SourceId, T]:\n    \"\"\"\n    A buffer that resequences events to ensure their ordering is preserved.\n    Tracks events with multiple sources\n    \"\"\"\n\n    def __init__(self):\n        self.stores: dict[SourceId, OrderedBuffer[T]] = {}\n\n    def ingest(self, idx: int, t: T, source: SourceId):\n        if source not in self.stores:\n            self.stores[source] = OrderedBuffer()\n        buffer = self.stores[source]\n        buffer.ingest(idx, t)\n\n    def drain(self) -> list[T]:\n        ret: list[T] = []\n        for store in self.stores.values():\n            ret.extend(store.drain())\n        return ret\n"
  },
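A small sketch of the resequencing behaviour: out-of-order ingests are held back until the gap fills, and indices below the release cursor are dropped:

```python
from exo.utils.event_buffer import OrderedBuffer

buffer = OrderedBuffer[str]()

buffer.ingest(1, "second")
assert buffer.drain() == []  # index 0 is still missing

buffer.ingest(0, "first")
assert buffer.drain() == ["first", "second"]

buffer.ingest(0, "first")  # below next_idx_to_release: silently ignored
assert buffer.drain() == []
```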
  {
    "path": "src/exo/utils/fs.py",
    "content": "import contextlib\nimport os\nimport pathlib\nimport tempfile\nfrom typing import LiteralString\n\ntype StrPath = str | os.PathLike[str]\ntype BytesPath = bytes | os.PathLike[bytes]\ntype StrOrBytesPath = str | bytes | os.PathLike[str] | os.PathLike[bytes]\n\n\ndef delete_if_exists(filename: StrOrBytesPath) -> None:\n    with contextlib.suppress(FileNotFoundError):\n        os.remove(filename)\n\n\ndef ensure_parent_directory_exists(filename: StrPath) -> None:\n    \"\"\"\n    Ensure the directory containing the file exists (create it if necessary).\n    \"\"\"\n    pathlib.Path(filename).parent.mkdir(parents=True, exist_ok=True)\n\n\ndef ensure_directory_exists(dirname: StrPath) -> None:\n    \"\"\"\n    Ensure the directory exists (create it if necessary).\n    \"\"\"\n    pathlib.Path(dirname).mkdir(parents=True, exist_ok=True)\n\n\ndef make_temp_path(name: LiteralString) -> str:\n    return os.path.join(tempfile.mkdtemp(), name)\n"
  },
  {
    "path": "src/exo/utils/info_gatherer/__init__.py",
    "content": ""
  },
  {
    "path": "src/exo/utils/info_gatherer/info_gatherer.py",
    "content": "import os\nimport shutil\nimport sys\nimport tomllib\nfrom collections.abc import Sequence\nfrom dataclasses import dataclass, field\nfrom subprocess import CalledProcessError\nfrom typing import Self, cast\n\nimport anyio\nfrom anyio import fail_after, open_process, to_thread\nfrom anyio.streams.buffered import BufferedByteReceiveStream\nfrom anyio.streams.text import TextReceiveStream\nfrom loguru import logger\nfrom pydantic import ValidationError\n\nfrom exo.shared.constants import EXO_CONFIG_FILE, EXO_MODELS_DIR\nfrom exo.shared.types.memory import Memory\nfrom exo.shared.types.profiling import (\n    DiskUsage,\n    MemoryUsage,\n    NetworkInterfaceInfo,\n    ThunderboltBridgeStatus,\n)\nfrom exo.shared.types.thunderbolt import (\n    ThunderboltConnection,\n    ThunderboltConnectivity,\n    ThunderboltIdentifier,\n)\nfrom exo.utils.channels import Sender\nfrom exo.utils.pydantic_ext import TaggedModel\nfrom exo.utils.task_group import TaskGroup\n\nfrom .macmon import MacmonMetrics\nfrom .system_info import (\n    get_friendly_name,\n    get_model_and_chip,\n    get_network_interfaces,\n    get_os_build_version,\n    get_os_version,\n)\n\nIS_DARWIN = sys.platform == \"darwin\"\n\n\nasync def _get_thunderbolt_devices() -> set[str] | None:\n    \"\"\"Get Thunderbolt interface device names (e.g., en2, en3) from hardware ports.\n\n    Returns None if the networksetup command fails.\n    \"\"\"\n    result = await anyio.run_process(\n        [\"networksetup\", \"-listallhardwareports\"],\n        check=False,\n    )\n    if result.returncode != 0:\n        logger.warning(\n            f\"networksetup -listallhardwareports failed with code \"\n            f\"{result.returncode}: {result.stderr.decode()}\"\n        )\n        return None\n\n    output = result.stdout.decode()\n    thunderbolt_devices: set[str] = set()\n    current_port: str | None = None\n\n    for line in output.splitlines():\n        line = line.strip()\n        if line.startswith(\"Hardware Port:\"):\n            current_port = line.split(\":\", 1)[1].strip()\n        elif line.startswith(\"Device:\") and current_port:\n            device = line.split(\":\", 1)[1].strip()\n            if \"thunderbolt\" in current_port.lower():\n                thunderbolt_devices.add(device)\n            current_port = None\n\n    return thunderbolt_devices\n\n\nasync def _get_bridge_services() -> dict[str, str] | None:\n    \"\"\"Get mapping of bridge device -> service name from network service order.\n\n    Returns None if the networksetup command fails.\n    \"\"\"\n    result = await anyio.run_process(\n        [\"networksetup\", \"-listnetworkserviceorder\"],\n        check=False,\n    )\n    if result.returncode != 0:\n        logger.warning(\n            f\"networksetup -listnetworkserviceorder failed with code \"\n            f\"{result.returncode}: {result.stderr.decode()}\"\n        )\n        return None\n\n    # Parse service order to find bridge devices and their service names\n    # Format: \"(1) Service Name\\n(Hardware Port: ..., Device: bridge0)\\n\"\n    service_order_output = result.stdout.decode()\n    bridge_services: dict[str, str] = {}  # device -> service name\n    current_service: str | None = None\n\n    for line in service_order_output.splitlines():\n        line = line.strip()\n        # Match \"(N) Service Name\" or \"(*) Service Name\" (disabled)\n        # but NOT \"(Hardware Port: ...)\" lines\n        if (\n            line\n            and line.startswith(\"(\")\n            and \")\" 
in line\n            and not line.startswith(\"(Hardware Port:\")\n        ):\n            paren_end = line.index(\")\")\n            if paren_end + 2 <= len(line):\n                current_service = line[paren_end + 2 :]\n        # Match \"(Hardware Port: ..., Device: bridgeX)\"\n        elif current_service and \"Device: bridge\" in line:\n            # Extract device name from \"..., Device: bridge0)\"\n            device_start = line.find(\"Device: \") + len(\"Device: \")\n            device_end = line.find(\")\", device_start)\n            if device_end > device_start:\n                device = line[device_start:device_end]\n                bridge_services[device] = current_service\n\n    return bridge_services\n\n\nasync def _get_bridge_members(bridge_device: str) -> set[str]:\n    \"\"\"Get member interfaces of a bridge device via ifconfig.\"\"\"\n    result = await anyio.run_process(\n        [\"ifconfig\", bridge_device],\n        check=False,\n    )\n    if result.returncode != 0:\n        logger.debug(f\"ifconfig {bridge_device} failed with code {result.returncode}\")\n        return set()\n\n    members: set[str] = set()\n    ifconfig_output = result.stdout.decode()\n    for line in ifconfig_output.splitlines():\n        line = line.strip()\n        if line.startswith(\"member:\"):\n            parts = line.split()\n            if len(parts) > 1:\n                members.add(parts[1])\n\n    return members\n\n\nasync def _find_thunderbolt_bridge(\n    bridge_services: dict[str, str], thunderbolt_devices: set[str]\n) -> str | None:\n    \"\"\"Find the service name of a bridge containing Thunderbolt interfaces.\n\n    Returns the service name if found, None otherwise.\n    \"\"\"\n    for bridge_device, service_name in bridge_services.items():\n        members = await _get_bridge_members(bridge_device)\n        if members & thunderbolt_devices:  # intersection is non-empty\n            return service_name\n    return None\n\n\nasync def _is_service_enabled(service_name: str) -> bool | None:\n    \"\"\"Check if a network service is enabled.\n\n    Returns True if enabled, False if disabled, None on error.\n    \"\"\"\n    result = await anyio.run_process(\n        [\"networksetup\", \"-getnetworkserviceenabled\", service_name],\n        check=False,\n    )\n    if result.returncode != 0:\n        logger.warning(\n            f\"networksetup -getnetworkserviceenabled '{service_name}' \"\n            f\"failed with code {result.returncode}: {result.stderr.decode()}\"\n        )\n        return None\n\n    stdout = result.stdout.decode().strip().lower()\n    return stdout == \"enabled\"\n\n\nclass StaticNodeInformation(TaggedModel):\n    \"\"\"Node information that should NEVER change, to be gathered once at startup\"\"\"\n\n    model: str\n    chip: str\n    os_version: str\n    os_build_version: str\n\n    @classmethod\n    async def gather(cls) -> Self:\n        model, chip = await get_model_and_chip()\n        return cls(\n            model=model,\n            chip=chip,\n            os_version=get_os_version(),\n            os_build_version=await get_os_build_version(),\n        )\n\n\nclass NodeNetworkInterfaces(TaggedModel):\n    ifaces: Sequence[NetworkInterfaceInfo]\n\n\nclass MacThunderboltIdentifiers(TaggedModel):\n    idents: Sequence[ThunderboltIdentifier]\n\n\nclass MacThunderboltConnections(TaggedModel):\n    conns: Sequence[ThunderboltConnection]\n\n\nclass RdmaCtlStatus(TaggedModel):\n    enabled: bool\n\n    @classmethod\n    async def gather(cls) -> Self | None:\n  
      if not IS_DARWIN or shutil.which(\"rdma_ctl\") is None:\n            return None\n        try:\n            with anyio.fail_after(5):\n                proc = await anyio.run_process([\"rdma_ctl\", \"status\"], check=False)\n        except (TimeoutError, OSError):\n            return None\n        if proc.returncode != 0:\n            return None\n        output = proc.stdout.decode(\"utf-8\").lower().strip()\n        if \"enabled\" in output:\n            return cls(enabled=True)\n        if \"disabled\" in output:\n            return cls(enabled=False)\n        return None\n\n\nclass ThunderboltBridgeInfo(TaggedModel):\n    status: ThunderboltBridgeStatus\n\n    @classmethod\n    async def gather(cls) -> Self | None:\n        \"\"\"Check if a Thunderbolt Bridge network service is enabled on this node.\n\n        Detection approach:\n        1. Find all Thunderbolt interface devices (en2, en3, etc.) from hardware ports\n        2. Find bridge devices from network service order (not hardware ports, as\n           bridges may not appear there)\n        3. Check each bridge's members via ifconfig\n        4. If a bridge contains Thunderbolt interfaces, it's a Thunderbolt Bridge\n        5. Check if that network service is enabled\n        \"\"\"\n        if not IS_DARWIN:\n            return None\n\n        def _no_bridge_status() -> Self:\n            return cls(\n                status=ThunderboltBridgeStatus(\n                    enabled=False, exists=False, service_name=None\n                )\n            )\n\n        try:\n            tb_devices = await _get_thunderbolt_devices()\n            if tb_devices is None:\n                return _no_bridge_status()\n\n            bridge_services = await _get_bridge_services()\n            if not bridge_services:\n                return _no_bridge_status()\n\n            tb_service_name = await _find_thunderbolt_bridge(\n                bridge_services, tb_devices\n            )\n            if not tb_service_name:\n                return _no_bridge_status()\n\n            enabled = await _is_service_enabled(tb_service_name)\n            if enabled is None:\n                return cls(\n                    status=ThunderboltBridgeStatus(\n                        enabled=False, exists=True, service_name=tb_service_name\n                    )\n                )\n\n            return cls(\n                status=ThunderboltBridgeStatus(\n                    enabled=enabled,\n                    exists=True,\n                    service_name=tb_service_name,\n                )\n            )\n        except Exception as e:\n            logger.warning(f\"Failed to gather Thunderbolt Bridge info: {e}\")\n            return None\n\n\nclass NodeConfig(TaggedModel):\n    \"\"\"Node configuration from EXO_CONFIG_FILE, reloaded from the file only at startup. 
Other changes should come in through the API and propagate from there\"\"\"\n\n    @classmethod\n    async def gather(cls) -> Self | None:\n        cfg_file = anyio.Path(EXO_CONFIG_FILE)\n        await cfg_file.parent.mkdir(parents=True, exist_ok=True)\n        await cfg_file.touch(exist_ok=True)\n        async with await cfg_file.open(\"rb\") as f:\n            try:\n                contents = (await f.read()).decode(\"utf-8\")\n                data = tomllib.loads(contents)\n                return cls.model_validate(data)\n            except (tomllib.TOMLDecodeError, UnicodeDecodeError, ValidationError):\n                logger.warning(\"Invalid config file, skipping...\")\n                return None\n\n\nclass MiscData(TaggedModel):\n    \"\"\"Node information that may slowly change that doesn't fall into the other categories\"\"\"\n\n    friendly_name: str\n\n    @classmethod\n    async def gather(cls) -> Self:\n        return cls(friendly_name=await get_friendly_name())\n\n\nclass NodeDiskUsage(TaggedModel):\n    \"\"\"Disk space information for the models directory.\"\"\"\n\n    disk_usage: DiskUsage\n\n    @classmethod\n    async def gather(cls) -> Self:\n        return cls(\n            disk_usage=await to_thread.run_sync(\n                lambda: DiskUsage.from_path(EXO_MODELS_DIR)\n            )\n        )\n\n\nasync def _gather_iface_map() -> dict[str, str] | None:\n    proc = await anyio.run_process(\n        [\"networksetup\", \"-listallhardwareports\"], check=False\n    )\n    if proc.returncode != 0:\n        return None\n\n    ports: dict[str, str] = {}\n    port = \"\"\n    for line in proc.stdout.decode(\"utf-8\").split(\"\\n\"):\n        if line.startswith(\"Hardware Port:\"):\n            port = line.split(\": \")[1]\n        elif line.startswith(\"Device:\"):\n            ports[port] = line.split(\": \")[1]\n            port = \"\"\n    if \"\" in ports:\n        del ports[\"\"]\n    return ports\n\n\nGatheredInfo = (\n    MacmonMetrics\n    | MemoryUsage\n    | NodeNetworkInterfaces\n    | MacThunderboltIdentifiers\n    | MacThunderboltConnections\n    | RdmaCtlStatus\n    | ThunderboltBridgeInfo\n    | NodeConfig\n    | MiscData\n    | StaticNodeInformation\n    | NodeDiskUsage\n)\n\n\n@dataclass\nclass InfoGatherer:\n    info_sender: Sender[GatheredInfo]\n    interface_watcher_interval: float | None = 10\n    misc_poll_interval: float | None = 60\n    system_profiler_interval: float | None = 5 if IS_DARWIN else None\n    memory_poll_rate: float | None = None if IS_DARWIN else 1\n    macmon_interval: float | None = 1 if IS_DARWIN else None\n    thunderbolt_bridge_poll_interval: float | None = 10 if IS_DARWIN else None\n    static_info_poll_interval: float | None = 60\n    rdma_ctl_poll_interval: float | None = 10 if IS_DARWIN else None\n    disk_poll_interval: float | None = 30\n    _tg: TaskGroup = field(init=False, default_factory=TaskGroup)\n\n    async def run(self):\n        async with self._tg as tg:\n            if IS_DARWIN:\n                if (macmon_path := shutil.which(\"macmon\")) is not None:\n                    tg.start_soon(self._monitor_macmon, macmon_path)\n                else:\n                    # macmon not installed — fall back to psutil for memory\n                    logger.warning(\n                        \"macmon not found, falling back to psutil for memory monitoring\"\n                    )\n                    self.memory_poll_rate = 1\n                tg.start_soon(self._monitor_system_profiler_thunderbolt_data)\n                
tg.start_soon(self._monitor_thunderbolt_bridge_status)\n                tg.start_soon(self._monitor_rdma_ctl_status)\n            tg.start_soon(self._watch_system_info)\n            tg.start_soon(self._monitor_memory_usage)\n            tg.start_soon(self._monitor_misc)\n            tg.start_soon(self._monitor_static_info)\n            tg.start_soon(self._monitor_disk_usage)\n\n            nc = await NodeConfig.gather()\n            if nc is not None:\n                await self.info_sender.send(nc)\n\n    def shutdown(self):\n        self._tg.cancel_tasks()\n\n    async def _monitor_static_info(self):\n        if self.static_info_poll_interval is None:\n            return\n        while True:\n            try:\n                with fail_after(30):\n                    await self.info_sender.send(await StaticNodeInformation.gather())\n            except Exception as e:\n                logger.warning(f\"Error gathering static node info: {e}\")\n            await anyio.sleep(self.static_info_poll_interval)\n\n    async def _monitor_misc(self):\n        if self.misc_poll_interval is None:\n            return\n        while True:\n            try:\n                with fail_after(10):\n                    await self.info_sender.send(await MiscData.gather())\n            except Exception as e:\n                logger.warning(f\"Error gathering misc data: {e}\")\n            await anyio.sleep(self.misc_poll_interval)\n\n    async def _monitor_system_profiler_thunderbolt_data(self):\n        if self.system_profiler_interval is None:\n            return\n\n        while True:\n            try:\n                with fail_after(30):\n                    iface_map = await _gather_iface_map()\n                    if iface_map is None:\n                        raise ValueError(\"Failed to gather interface map\")\n\n                    data = await ThunderboltConnectivity.gather()\n                    assert data is not None\n\n                    idents = [\n                        it for i in data if (it := i.ident(iface_map)) is not None\n                    ]\n                    await self.info_sender.send(\n                        MacThunderboltIdentifiers(idents=idents)\n                    )\n\n                    conns = [it for i in data if (it := i.conn()) is not None]\n                    await self.info_sender.send(MacThunderboltConnections(conns=conns))\n            except Exception as e:\n                logger.warning(f\"Error gathering Thunderbolt data: {e}\")\n            await anyio.sleep(self.system_profiler_interval)\n\n    async def _monitor_memory_usage(self):\n        override_memory_env = os.getenv(\"OVERRIDE_MEMORY_MB\")\n        override_memory: int | None = (\n            Memory.from_mb(int(override_memory_env)).in_bytes\n            if override_memory_env\n            else None\n        )\n        if self.memory_poll_rate is None:\n            return\n        while True:\n            try:\n                await self.info_sender.send(\n                    MemoryUsage.from_psutil(override_memory=override_memory)\n                )\n            except Exception as e:\n                logger.warning(f\"Error gathering memory usage: {e}\")\n            await anyio.sleep(self.memory_poll_rate)\n\n    async def _watch_system_info(self):\n        if self.interface_watcher_interval is None:\n            return\n        while True:\n            try:\n                with fail_after(10):\n                    nics = await get_network_interfaces()\n                    await 
self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))\n            except Exception as e:\n                logger.warning(f\"Error gathering network interfaces: {e}\")\n            await anyio.sleep(self.interface_watcher_interval)\n\n    async def _monitor_thunderbolt_bridge_status(self):\n        if self.thunderbolt_bridge_poll_interval is None:\n            return\n        while True:\n            try:\n                with fail_after(30):\n                    curr = await ThunderboltBridgeInfo.gather()\n                    if curr is not None:\n                        await self.info_sender.send(curr)\n            except Exception as e:\n                logger.warning(f\"Error gathering Thunderbolt Bridge status: {e}\")\n            await anyio.sleep(self.thunderbolt_bridge_poll_interval)\n\n    async def _monitor_rdma_ctl_status(self):\n        if self.rdma_ctl_poll_interval is None:\n            return\n        while True:\n            try:\n                curr = await RdmaCtlStatus.gather()\n                if curr is not None:\n                    await self.info_sender.send(curr)\n            except Exception as e:\n                logger.warning(f\"Error gathering RDMA ctl status: {e}\")\n            await anyio.sleep(self.rdma_ctl_poll_interval)\n\n    async def _monitor_disk_usage(self):\n        if self.disk_poll_interval is None:\n            return\n        while True:\n            try:\n                with fail_after(5):\n                    await self.info_sender.send(await NodeDiskUsage.gather())\n            except Exception as e:\n                logger.warning(f\"Error gathering disk usage: {e}\")\n            await anyio.sleep(self.disk_poll_interval)\n\n    async def _monitor_macmon(self, macmon_path: str):\n        if self.macmon_interval is None:\n            return\n        # macmon pipe --interval [interval in ms]\n        # Timeout: if macmon produces no output for this many seconds, restart it.\n        # macmon writes every macmon_interval seconds, so 10x that is generous.\n        read_timeout = max(self.macmon_interval * 10, 30)\n        while True:\n            try:\n                async with await open_process(\n                    [\n                        macmon_path,\n                        \"pipe\",\n                        \"--interval\",\n                        str(self.macmon_interval * 1000),\n                    ]\n                ) as p:\n                    if not p.stdout:\n                        logger.critical(\"MacMon closed stdout\")\n                        return\n                    stream = TextReceiveStream(BufferedByteReceiveStream(p.stdout))\n                    while True:\n                        with fail_after(read_timeout):\n                            text = await stream.receive()\n                        await self.info_sender.send(MacmonMetrics.from_raw_json(text))\n            except TimeoutError:\n                logger.warning(\n                    f\"MacMon produced no output for {read_timeout}s, restarting\"\n                )\n            except CalledProcessError as e:\n                stderr_msg = \"no stderr\"\n                stderr_output = cast(bytes | str | None, e.stderr)\n                if stderr_output is not None:\n                    stderr_msg = (\n                        stderr_output.decode()\n                        if isinstance(stderr_output, bytes)\n                        else str(stderr_output)\n                    )\n                logger.warning(\n                    f\"MacMon failed 
with return code {e.returncode}: {stderr_msg}\"\n                )\n            except Exception as e:\n                logger.warning(f\"Error in macmon monitor: {e}\")\n            await anyio.sleep(self.macmon_interval)\n"
  },
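  {
    "path": "examples/info_gatherer_usage.py",
    "content": "# Hypothetical wiring sketch, added for illustration; not part of the original\n# source tree. Shows one way an InfoGatherer could be connected to a channel\n# and drained, following the channel[...]() pattern used in net_profile.py.\nimport anyio\n\nfrom exo.utils.channels import channel\nfrom exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer\n\n\nasync def main() -> None:\n    send, recv = channel[GatheredInfo]()\n    gatherer = InfoGatherer(info_sender=send)\n\n    async with anyio.create_task_group() as tg:\n        tg.start_soon(gatherer.run)\n        with recv:\n            async for info in recv:\n                print(type(info).__name__)  # e.g. MiscData, NodeDiskUsage, ...\n                break  # demo only: stop after the first report\n        gatherer.shutdown()\n\n\nanyio.run(main)\n"
  },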
  {
    "path": "src/exo/utils/info_gatherer/macmon.py",
    "content": "from typing import Self\n\nfrom pydantic import BaseModel\n\nfrom exo.shared.types.profiling import MemoryUsage, SystemPerformanceProfile\nfrom exo.utils.pydantic_ext import TaggedModel\n\n\nclass _TempMetrics(BaseModel, extra=\"ignore\"):\n    \"\"\"Temperature-related metrics returned by macmon.\"\"\"\n\n    cpu_temp_avg: float\n    gpu_temp_avg: float\n\n\nclass _MemoryMetrics(BaseModel, extra=\"ignore\"):\n    \"\"\"Memory-related metrics returned by macmon.\"\"\"\n\n    ram_total: int\n    ram_usage: int\n    swap_total: int\n    swap_usage: int\n\n\nclass RawMacmonMetrics(BaseModel, extra=\"ignore\"):\n    \"\"\"Complete set of metrics returned by macmon.\n\n    Unknown fields are ignored for forward-compatibility.\n    \"\"\"\n\n    timestamp: str  # ignored\n    temp: _TempMetrics\n    memory: _MemoryMetrics\n    ecpu_usage: tuple[int, float]  # freq mhz, usage %\n    pcpu_usage: tuple[int, float]  # freq mhz, usage %\n    gpu_usage: tuple[int, float]  # freq mhz, usage %\n    all_power: float\n    ane_power: float\n    cpu_power: float\n    gpu_power: float\n    gpu_ram_power: float\n    ram_power: float\n    sys_power: float\n\n\nclass MacmonMetrics(TaggedModel):\n    system_profile: SystemPerformanceProfile\n    memory: MemoryUsage\n\n    @classmethod\n    def from_raw(cls, raw: RawMacmonMetrics) -> Self:\n        return cls(\n            system_profile=SystemPerformanceProfile(\n                gpu_usage=raw.gpu_usage[1],\n                temp=raw.temp.gpu_temp_avg,\n                sys_power=raw.sys_power,\n                pcpu_usage=raw.pcpu_usage[1],\n                ecpu_usage=raw.ecpu_usage[1],\n            ),\n            memory=MemoryUsage.from_bytes(\n                ram_total=raw.memory.ram_total,\n                ram_available=(raw.memory.ram_total - raw.memory.ram_usage),\n                swap_total=raw.memory.swap_total,\n                swap_available=(raw.memory.swap_total - raw.memory.swap_usage),\n            ),\n        )\n\n    @classmethod\n    def from_raw_json(cls, json: str) -> Self:\n        return cls.from_raw(RawMacmonMetrics.model_validate_json(json))\n"
  },
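  {
    "path": "examples/macmon_parsing.py",
    "content": "# Hypothetical parsing sketch, added for illustration; not part of the original\n# source tree. The payload is shaped to match RawMacmonMetrics with invented\n# values, then run through MacmonMetrics.from_raw_json.\nimport json\n\nfrom exo.utils.info_gatherer.macmon import MacmonMetrics\n\nsample = {\n    \"timestamp\": \"2025-01-01T00:00:00Z\",\n    \"temp\": {\"cpu_temp_avg\": 45.0, \"gpu_temp_avg\": 40.0},\n    \"memory\": {\n        \"ram_total\": 17179869184,\n        \"ram_usage\": 8589934592,\n        \"swap_total\": 0,\n        \"swap_usage\": 0,\n    },\n    \"ecpu_usage\": [1200, 0.2],\n    \"pcpu_usage\": [3200, 0.5],\n    \"gpu_usage\": [900, 0.1],\n    \"all_power\": 10.0,\n    \"ane_power\": 0.0,\n    \"cpu_power\": 4.0,\n    \"gpu_power\": 2.0,\n    \"gpu_ram_power\": 0.5,\n    \"ram_power\": 1.0,\n    \"sys_power\": 12.5,\n}\n\nmetrics = MacmonMetrics.from_raw_json(json.dumps(sample))\nprint(metrics.system_profile.gpu_usage)  # 0.1, taken from gpu_usage[1]\n"
  },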
  {
    "path": "src/exo/utils/info_gatherer/net_profile.py",
    "content": "from collections import defaultdict\nfrom collections.abc import AsyncGenerator, Mapping\n\nimport anyio\nimport httpx\nfrom anyio import create_task_group\nfrom loguru import logger\n\nfrom exo.shared.topology import Topology\nfrom exo.shared.types.common import NodeId\nfrom exo.shared.types.profiling import NodeNetworkInfo\nfrom exo.utils.channels import Sender, channel\n\nREACHABILITY_ATTEMPTS = 3\n\n\nasync def check_reachability(\n    target_ip: str,\n    expected_node_id: NodeId,\n    out: dict[NodeId, set[str]],\n    client: httpx.AsyncClient,\n) -> None:\n    \"\"\"Check if a node is reachable at the given IP and verify its identity.\"\"\"\n    if \":\" in target_ip:\n        # TODO: use real IpAddress types\n        url = f\"http://[{target_ip}]:52415/node_id\"\n    else:\n        url = f\"http://{target_ip}:52415/node_id\"\n\n    remote_node_id = None\n    last_error = None\n\n    for _ in range(REACHABILITY_ATTEMPTS):\n        try:\n            r = await client.get(url)\n            if r.status_code != 200:\n                await anyio.sleep(1)\n                continue\n\n            body = r.text.strip().strip('\"')\n            if not body:\n                await anyio.sleep(1)\n                continue\n\n            remote_node_id = NodeId(body)\n            break\n\n        # expected failure cases\n        except (\n            httpx.TimeoutException,\n            httpx.NetworkError,\n        ):\n            await anyio.sleep(1)\n\n        # other failures should be logged on last attempt\n        except httpx.HTTPError as e:\n            last_error = e\n            await anyio.sleep(1)\n\n    if last_error is not None:\n        logger.warning(\n            f\"connect error {type(last_error).__name__} from {target_ip} after {REACHABILITY_ATTEMPTS} attempts; treating as down\"\n        )\n\n    if remote_node_id is None:\n        return\n\n    if remote_node_id != expected_node_id:\n        logger.debug(\n            f\"Discovered node with unexpected node_id; \"\n            f\"ip={target_ip}, expected_node_id={expected_node_id}, \"\n            f\"remote_node_id={remote_node_id}\"\n        )\n        return\n\n    if remote_node_id not in out:\n        out[remote_node_id] = set()\n    out[remote_node_id].add(target_ip)\n\n\nasync def check_reachable(\n    topology: Topology,\n    self_node_id: NodeId,\n    node_network: Mapping[NodeId, NodeNetworkInfo],\n) -> AsyncGenerator[tuple[str, NodeId], None]:\n    \"\"\"Yield (ip, node_id) pairs as reachability probes complete.\"\"\"\n\n    send, recv = channel[tuple[str, NodeId]]()\n\n    # these are intentionally httpx's defaults so we can tune them later\n    timeout = httpx.Timeout(timeout=5.0)\n    limits = httpx.Limits(\n        max_connections=100,\n        max_keepalive_connections=20,\n        keepalive_expiry=5,\n    )\n\n    async def _probe(\n        target_ip: str,\n        expected_node_id: NodeId,\n        client: httpx.AsyncClient,\n        send: Sender[tuple[str, NodeId]],\n    ) -> None:\n        async with send:\n            out: defaultdict[NodeId, set[str]] = defaultdict(set)\n            await check_reachability(target_ip, expected_node_id, out, client)\n            if expected_node_id in out:\n                await send.send((target_ip, expected_node_id))\n\n    async with (\n        httpx.AsyncClient(timeout=timeout, limits=limits, verify=False) as client,\n        create_task_group() as tg,\n    ):\n        for node_id in topology.list_nodes():\n            if node_id not in 
node_network:\n                continue\n            if node_id == self_node_id:\n                continue\n            for iface in node_network[node_id].interfaces:\n                tg.start_soon(_probe, iface.ip_address, node_id, client, send.clone())\n        send.close()\n\n        with recv:\n            async for item in recv:\n                yield item\n"
  },
  {
    "path": "src/exo/utils/info_gatherer/system_info.py",
    "content": "import platform\nimport socket\nimport sys\nfrom subprocess import CalledProcessError\n\nimport psutil\nfrom anyio import run_process\n\nfrom exo.shared.types.profiling import InterfaceType, NetworkInterfaceInfo\n\n\ndef get_os_version() -> str:\n    \"\"\"Return the OS version string for this node.\n\n    On macOS this is the macOS version (e.g. ``\"15.3\"``).\n    On other platforms it falls back to the platform name (e.g. ``\"Linux\"``).\n    \"\"\"\n    if sys.platform == \"darwin\":\n        version = platform.mac_ver()[0]\n        return version if version else \"Unknown\"\n    return platform.system() or \"Unknown\"\n\n\nasync def get_os_build_version() -> str:\n    \"\"\"Return the macOS build version string (e.g. ``\"24D5055b\"``).\n\n    On non-macOS platforms, returns ``\"Unknown\"``.\n    \"\"\"\n    if sys.platform != \"darwin\":\n        return \"Unknown\"\n\n    try:\n        process = await run_process([\"sw_vers\", \"-buildVersion\"])\n    except CalledProcessError:\n        return \"Unknown\"\n\n    return process.stdout.decode(\"utf-8\", errors=\"replace\").strip() or \"Unknown\"\n\n\nasync def get_friendly_name() -> str:\n    \"\"\"\n    Asynchronously gets the 'Computer Name' (friendly name) of a Mac.\n    e.g., \"John's MacBook Pro\"\n    Returns the name as a string, or None if an error occurs or not on macOS.\n    \"\"\"\n    hostname = socket.gethostname()\n\n    if sys.platform != \"darwin\":\n        return hostname\n\n    try:\n        process = await run_process([\"scutil\", \"--get\", \"ComputerName\"])\n    except CalledProcessError:\n        return hostname\n\n    return process.stdout.decode(\"utf-8\", errors=\"replace\").strip() or hostname\n\n\nasync def _get_interface_types_from_networksetup() -> dict[str, InterfaceType]:\n    \"\"\"Parse networksetup -listallhardwareports to get interface types.\"\"\"\n    if sys.platform != \"darwin\":\n        return {}\n\n    try:\n        result = await run_process([\"networksetup\", \"-listallhardwareports\"])\n    except CalledProcessError:\n        return {}\n\n    types: dict[str, InterfaceType] = {}\n    current_type: InterfaceType = \"unknown\"\n\n    for line in result.stdout.decode().splitlines():\n        if line.startswith(\"Hardware Port:\"):\n            port_name = line.split(\":\", 1)[1].strip()\n            if \"Wi-Fi\" in port_name:\n                current_type = \"wifi\"\n            elif \"Ethernet\" in port_name or \"LAN\" in port_name:\n                current_type = \"ethernet\"\n            elif port_name.startswith(\"Thunderbolt\"):\n                current_type = \"thunderbolt\"\n            else:\n                current_type = \"unknown\"\n        elif line.startswith(\"Device:\"):\n            device = line.split(\":\", 1)[1].strip()\n            # enX is ethernet adapters or thunderbolt - these must be deprioritised\n            if device.startswith(\"en\") and device not in [\"en0\", \"en1\"]:\n                current_type = \"maybe_ethernet\"\n            types[device] = current_type\n\n    return types\n\n\nasync def get_network_interfaces() -> list[NetworkInterfaceInfo]:\n    \"\"\"\n    Retrieves detailed network interface information on macOS.\n    Parses output from 'networksetup -listallhardwareports' and 'ifconfig'\n    to determine interface names, IP addresses, and types (ethernet, wifi, vpn, other).\n    Returns a list of NetworkInterfaceInfo objects.\n    \"\"\"\n    interfaces_info: list[NetworkInterfaceInfo] = []\n    interface_types = await 
_get_interface_types_from_networksetup()\n\n    for iface, services in psutil.net_if_addrs().items():\n        for service in services:\n            match service.family:\n                case socket.AF_INET | socket.AF_INET6:\n                    interfaces_info.append(\n                        NetworkInterfaceInfo(\n                            name=iface,\n                            ip_address=service.address,\n                            interface_type=interface_types.get(iface, \"unknown\"),\n                        )\n                    )\n                case _:\n                    pass\n\n    return interfaces_info\n\n\nasync def get_model_and_chip() -> tuple[str, str]:\n    \"\"\"Get Mac system information using system_profiler.\"\"\"\n    model = \"Unknown Model\"\n    chip = \"Unknown Chip\"\n\n    # TODO: better non mac support\n    if sys.platform != \"darwin\":\n        return (model, chip)\n\n    try:\n        process = await run_process(\n            [\n                \"system_profiler\",\n                \"SPHardwareDataType\",\n            ]\n        )\n    except CalledProcessError:\n        return (model, chip)\n\n    # less interested in errors here because this value should be hard coded\n    output = process.stdout.decode().strip()\n\n    model_line = next(\n        (line for line in output.split(\"\\n\") if \"Model Name\" in line), None\n    )\n    model = model_line.split(\": \")[1] if model_line else \"Unknown Model\"\n\n    chip_line = next((line for line in output.split(\"\\n\") if \"Chip\" in line), None)\n    chip = chip_line.split(\": \")[1] if chip_line else \"Unknown Chip\"\n\n    return (model, chip)\n"
  },
  {
    "path": "src/exo/utils/info_gatherer/tests/test_tb_parsing.py",
    "content": "import sys\n\nimport pytest\n\nfrom exo.shared.types.thunderbolt import (\n    ThunderboltConnectivity,\n)\nfrom exo.utils.info_gatherer.info_gatherer import (\n    _gather_iface_map,  # pyright: ignore[reportPrivateUsage]\n)\n\n\n@pytest.mark.anyio\n@pytest.mark.skipif(\n    sys.platform != \"darwin\", reason=\"Thunderbolt info can only be gathered on macos\"\n)\nasync def test_tb_parsing():\n    data = await ThunderboltConnectivity.gather()\n    ifaces = await _gather_iface_map()\n    assert ifaces\n    assert data\n    for datum in data:\n        datum.ident(ifaces)\n        datum.conn()\n"
  },
  {
    "path": "src/exo/utils/keyed_backoff.py",
    "content": "import time\nfrom typing import Generic, TypeVar\n\nK = TypeVar(\"K\")\n\n\nclass KeyedBackoff(Generic[K]):\n    \"\"\"Tracks exponential backoff state per key.\"\"\"\n\n    def __init__(self, base: float = 0.5, cap: float = 10.0):\n        self._base = base\n        self._cap = cap\n        self._attempts: dict[K, int] = {}\n        self._last_time: dict[K, float] = {}\n\n    def should_proceed(self, key: K) -> bool:\n        \"\"\"Returns True if enough time has elapsed since last attempt.\"\"\"\n        now = time.monotonic()\n        last = self._last_time.get(key, 0.0)\n        attempts = self._attempts.get(key, 0)\n        delay = min(self._cap, self._base * (2.0**attempts))\n        return now - last >= delay\n\n    def record_attempt(self, key: K) -> None:\n        \"\"\"Record that an attempt was made for this key.\"\"\"\n        self._last_time[key] = time.monotonic()\n        self._attempts[key] = self._attempts.get(key, 0) + 1\n\n    def reset(self, key: K) -> None:\n        \"\"\"Reset backoff state for a key (e.g., on success).\"\"\"\n        self._attempts.pop(key, None)\n        self._last_time.pop(key, None)\n"
  },
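  {
    "path": "examples/keyed_backoff_usage.py",
    "content": "# Hypothetical usage sketch, added for illustration; not part of the original\n# source tree. KeyedBackoff gates retries per key, so each host advances\n# through its own exponential delay; probe() is invented for the demo.\nimport time\n\nfrom exo.utils.keyed_backoff import KeyedBackoff\n\n\ndef probe(host: str) -> bool:\n    return host == \"good-host\"  # stand-in for a real health check\n\n\nbackoff = KeyedBackoff[str](base=0.1, cap=1.0)\n\nfor _ in range(3):\n    for host in (\"good-host\", \"bad-host\"):\n        if not backoff.should_proceed(host):\n            continue  # this key is still backing off\n        backoff.record_attempt(host)\n        if probe(host):\n            backoff.reset(host)  # success clears the key's state\n    time.sleep(0.1)\n"
  },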
  {
    "path": "src/exo/utils/phantom.py",
    "content": "class _PhantomData[*T]:\n    \"\"\"\n    Internal machinery of the phantom data - it stores nothing.\n    \"\"\"\n\n\ntype PhantomData[*T] = _PhantomData[*T] | None\n\"\"\"\nAllows you to use generics in functions without storing anything of that generic type. \nJust use `None` and you'll be fine\n\"\"\"\n"
  },
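  {
    "path": "examples/phantom_usage.py",
    "content": "# Hypothetical usage sketch, added for illustration; not part of the original\n# source tree. PhantomData carries a type parameter that is never stored, so\n# `None` is the entire runtime value; TypedHandle is invented for the demo.\nfrom dataclasses import dataclass\n\nfrom exo.utils.phantom import PhantomData\n\n\n@dataclass\nclass TypedHandle[T]:\n    name: str\n    marker: PhantomData[T] = None  # zero runtime cost; only the type remains\n\n\n# Identical runtime data, distinct static types.\nints: TypedHandle[int] = TypedHandle(name=\"ints\")\nstrs: TypedHandle[str] = TypedHandle(name=\"strs\")\n"
  },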
  {
    "path": "src/exo/utils/power_sampler.py",
    "content": "import time\nfrom collections import defaultdict\nfrom collections.abc import Callable, Mapping\nfrom typing import final\n\nimport anyio\n\nfrom exo.api.types import NodePowerStats, PowerUsage\nfrom exo.shared.types.common import NodeId\nfrom exo.shared.types.profiling import SystemPerformanceProfile\n\n\n@final\nclass PowerSampler:\n    def __init__(\n        self,\n        get_node_system: Callable[[], Mapping[NodeId, SystemPerformanceProfile]],\n        interval: float = 1.0,\n    ):\n        self._get_node_system = get_node_system\n        self._interval = interval\n        self._samples: defaultdict[NodeId, list[SystemPerformanceProfile]] = (\n            defaultdict(list)\n        )\n        self._start_time: float | None = None\n        self._stopped = False\n\n    def _take_sample(self) -> None:\n        for node_id, profile in self._get_node_system().items():\n            self._samples[node_id].append(profile)\n\n    async def run(self) -> None:\n        self._start_time = time.perf_counter()\n        self._take_sample()\n        while not self._stopped:\n            await anyio.sleep(self._interval)\n            self._take_sample()\n\n    def result(self) -> PowerUsage:\n        self._stopped = True\n        assert self._start_time is not None, \"result() called before run()\"\n        self._take_sample()\n        elapsed = time.perf_counter() - self._start_time\n\n        node_stats: list[NodePowerStats] = []\n        for node_id, profiles in self._samples.items():\n            n = len(profiles)\n            if n == 0:\n                continue\n            node_stats.append(\n                NodePowerStats(\n                    node_id=node_id,\n                    samples=n,\n                    avg_sys_power=sum(p.sys_power for p in profiles) / n,\n                )\n            )\n\n        total_avg_sys = sum(ns.avg_sys_power for ns in node_stats)\n        return PowerUsage(\n            elapsed_seconds=elapsed,\n            nodes=node_stats,\n            total_avg_sys_power_watts=total_avg_sys,\n            total_energy_joules=total_avg_sys * elapsed,\n        )\n"
  },
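  {
    "path": "examples/power_sampler_usage.py",
    "content": "# Hypothetical usage sketch, added for illustration; not part of the original\n# source tree. Runs a PowerSampler next to a (pretend) workload; node id and\n# power values are invented, and result() stops the sampling loop.\nimport anyio\n\nfrom exo.shared.types.common import NodeId\nfrom exo.shared.types.profiling import SystemPerformanceProfile\nfrom exo.utils.power_sampler import PowerSampler\n\n\nasync def main() -> None:\n    state = {NodeId(\"node-a\"): SystemPerformanceProfile(sys_power=12.0)}\n    sampler = PowerSampler(get_node_system=lambda: state, interval=0.05)\n\n    async with anyio.create_task_group() as tg:\n        tg.start_soon(sampler.run)\n        await anyio.sleep(0.2)  # the measured workload would run here\n        usage = sampler.result()  # stops the loop and takes a final sample\n        tg.cancel_scope.cancel()\n\n    print(usage.total_avg_sys_power_watts, usage.total_energy_joules)\n\n\nanyio.run(main)\n"
  },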
  {
    "path": "src/exo/utils/pydantic_ext.py",
    "content": "# pyright: reportAny=false, reportUnknownArgumentType=false, reportUnknownVariableType=false\n\nfrom typing import Any, Self\n\nfrom pydantic import BaseModel, ConfigDict, model_serializer, model_validator\nfrom pydantic.alias_generators import to_camel\nfrom pydantic_core.core_schema import (\n    SerializerFunctionWrapHandler,\n    ValidatorFunctionWrapHandler,\n)\n\n\nclass CamelCaseModel(BaseModel):\n    \"\"\"\n    A model whose fields are aliased to camel-case from snake-case.\n    \"\"\"\n\n    model_config = ConfigDict(\n        alias_generator=to_camel,\n        validate_by_name=True,\n        extra=\"forbid\",\n        strict=True,\n    )\n\n\nclass FrozenModel(BaseModel):\n    model_config = ConfigDict(\n        alias_generator=to_camel,\n        validate_by_name=True,\n        extra=\"forbid\",\n        strict=True,\n        frozen=True,\n    )\n\n\nclass TaggedModel(CamelCaseModel):\n    @model_serializer(mode=\"wrap\")\n    def _serialize(self, handler: SerializerFunctionWrapHandler):\n        inner = handler(self)\n        return {self.__class__.__name__: inner}\n\n    @model_validator(mode=\"wrap\")\n    @classmethod\n    def _validate(cls, v: Any, handler: ValidatorFunctionWrapHandler) -> Self:\n        if isinstance(v, dict) and len(v) == 1 and cls.__name__ in v:\n            return handler(v[cls.__name__])\n\n        return handler(v)\n\n    def __str__(self) -> str:\n        return f\"{self.__class__.__name__}({super().__str__()})\"\n"
  },
  {
    "path": "src/exo/utils/reactive.py",
    "content": "\"\"\"\nUtilities for reactive variables\n\n\"\"\"\n\nfrom typing import Protocol\n\n\nclass OnChange[T](Protocol):\n    def __call__(self, old_value: T, new_value: T) -> None: ...\n\n\nclass Reactive[T]:\n    def __init__(self, initial_value: T, on_change: OnChange[T]):\n        self._value = initial_value\n        self._on_change = on_change\n\n    @property\n    def value(self):\n        return self._value\n\n    @value.setter\n    def value(self, new_value: T):\n        old_value = self._value\n        self._value = new_value\n\n        # don't notify when not changed\n        if old_value == new_value:\n            return\n\n        # notify of changes\n        self._on_change(old_value=old_value, new_value=new_value)\n"
  },
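  {
    "path": "examples/reactive_usage.py",
    "content": "# Hypothetical usage sketch, added for illustration; not part of the original\n# source tree. The on_change callback fires only when the assigned value\n# actually differs from the previous one.\nfrom exo.utils.reactive import Reactive\n\n\ndef announce(old_value: int, new_value: int) -> None:\n    print(f\"changed: {old_value} -> {new_value}\")\n\n\ncounter = Reactive(0, on_change=announce)\ncounter.value = 1  # prints \"changed: 0 -> 1\"\ncounter.value = 1  # same value, no notification\n"
  },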
  {
    "path": "src/exo/utils/task_group.py",
    "content": "from collections.abc import Awaitable, Callable\nfrom dataclasses import dataclass, field\nfrom types import TracebackType\nfrom typing import Any, Unpack\n\nfrom anyio import create_task_group\nfrom anyio.abc import TaskGroup as TaskGroupABC\n\n\n@dataclass\nclass TaskGroup:\n    _tg: TaskGroupABC | None = field(default=None, init=False)\n    _queued: list[tuple[Any, Any, Any]] | None = field(default_factory=list, init=False)\n\n    def is_running(self) -> bool:\n        return self._tg is not None\n\n    def cancel_tasks(self):\n        assert self._tg\n        self._tg.cancel_scope.cancel()\n\n    def cancel_called(self) -> bool:\n        assert self._tg\n        return self._tg.cancel_scope.cancel_called\n\n    def start_soon[*T](\n        self,\n        func: Callable[[Unpack[T]], Awaitable[Any]],\n        *args: Unpack[T],\n        name: object = None,\n    ) -> None:\n        assert self._tg is not None\n        assert self._queued is None\n        self._tg.start_soon(func, *args, name=name)\n\n    def queue[*T](\n        self,\n        func: Callable[[Unpack[T]], Awaitable[Any]],\n        *args: Unpack[T],\n        name: object = None,\n    ) -> None:\n        assert self._tg is None\n        assert self._queued is not None\n        self._queued.append((func, args, name))\n\n    async def __aenter__(self) -> TaskGroupABC:\n        assert self._tg is None\n        assert self._queued is not None\n        self._tg = create_task_group()\n        r = await self._tg.__aenter__()\n        for func, args, name in self._queued:  # pyright: ignore[reportAny]\n            self._tg.start_soon(func, *args, name=name)  # pyright: ignore[reportAny]\n        self._queued = None\n        return r\n\n    async def __aexit__(\n        self,\n        exc_type: type[BaseException] | None,\n        exc_val: BaseException | None,\n        exc_tb: TracebackType | None,\n    ) -> bool:\n        \"\"\"Exit the task group context waiting for all tasks to finish.\"\"\"\n        assert self._tg is not None, \"aenter sets self.lazy, so it exists when we aexit\"\n        assert self._queued is None\n        return await self._tg.__aexit__(exc_type, exc_val, exc_tb)\n"
  },
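  {
    "path": "examples/task_group_usage.py",
    "content": "# Hypothetical usage sketch, added for illustration; not part of the original\n# source tree. queue() stages work before the group exists, __aenter__ then\n# launches everything queued, and start_soon() takes over once running.\nimport anyio\n\nfrom exo.utils.task_group import TaskGroup\n\n\nasync def tick(label: str) -> None:\n    await anyio.sleep(0.01)\n    print(\"done:\", label)\n\n\nasync def main() -> None:\n    tg = TaskGroup()\n    tg.queue(tick, \"queued-before-start\")  # only legal before entering\n    async with tg:\n        tg.start_soon(tick, \"started-after-enter\")  # only legal once running\n\n\nanyio.run(main)\n"
  },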
  {
    "path": "src/exo/utils/tests/test_event_log.py",
    "content": "from pathlib import Path\n\nimport pytest\n\nfrom exo.shared.types.events import TestEvent\nfrom exo.utils.disk_event_log import DiskEventLog\n\n\n@pytest.fixture\ndef log_dir(tmp_path: Path) -> Path:\n    return tmp_path / \"event_log\"\n\n\ndef test_append_and_read_back(log_dir: Path):\n    log = DiskEventLog(log_dir)\n    events = [TestEvent() for _ in range(5)]\n    for e in events:\n        log.append(e)\n\n    assert len(log) == 5\n\n    result = list(log.read_all())\n    assert len(result) == 5\n    for original, restored in zip(events, result, strict=True):\n        assert original.event_id == restored.event_id\n\n    log.close()\n\n\ndef test_read_range(log_dir: Path):\n    log = DiskEventLog(log_dir)\n    events = [TestEvent() for _ in range(10)]\n    for e in events:\n        log.append(e)\n\n    result = list(log.read_range(3, 7))\n    assert len(result) == 4\n    for i, restored in enumerate(result):\n        assert events[3 + i].event_id == restored.event_id\n\n    log.close()\n\n\ndef test_read_range_bounds(log_dir: Path):\n    log = DiskEventLog(log_dir)\n    events = [TestEvent() for _ in range(3)]\n    for e in events:\n        log.append(e)\n\n    # Start beyond count\n    assert list(log.read_range(5, 10)) == []\n    # Negative start\n    assert list(log.read_range(-1, 2)) == []\n    # End beyond count is clamped\n    result = list(log.read_range(1, 100))\n    assert len(result) == 2\n\n    log.close()\n\n\ndef test_empty_log(log_dir: Path):\n    log = DiskEventLog(log_dir)\n    assert len(log) == 0\n    assert list(log.read_all()) == []\n    assert list(log.read_range(0, 10)) == []\n    log.close()\n\n\ndef _archives(log_dir: Path) -> list[Path]:\n    return sorted(log_dir.glob(\"events.*.bin.zst\"))\n\n\ndef test_rotation_on_close(log_dir: Path):\n    log = DiskEventLog(log_dir)\n    log.append(TestEvent())\n    log.close()\n\n    active = log_dir / \"events.bin\"\n    assert not active.exists()\n\n    archives = _archives(log_dir)\n    assert len(archives) == 1\n    assert archives[0].stat().st_size > 0\n\n\ndef test_rotation_on_construction_with_stale_file(log_dir: Path):\n    log_dir.mkdir(parents=True, exist_ok=True)\n    (log_dir / \"events.bin\").write_bytes(b\"stale data\")\n\n    log = DiskEventLog(log_dir)\n    archives = _archives(log_dir)\n    assert len(archives) == 1\n    assert archives[0].exists()\n    assert len(log) == 0\n\n    log.close()\n\n\ndef test_empty_log_no_archive(log_dir: Path):\n    \"\"\"Closing an empty log should not leave an archive.\"\"\"\n    log = DiskEventLog(log_dir)\n    log.close()\n\n    active = log_dir / \"events.bin\"\n\n    assert not active.exists()\n    assert _archives(log_dir) == []\n\n\ndef test_close_is_idempotent(log_dir: Path):\n    log = DiskEventLog(log_dir)\n    log.append(TestEvent())\n    log.close()\n    archive = _archives(log_dir)\n    log.close()  # should not raise\n\n    assert _archives(log_dir) == archive\n\n\ndef test_successive_sessions(log_dir: Path):\n    \"\"\"Simulate two master sessions: both archives should be kept.\"\"\"\n    log1 = DiskEventLog(log_dir)\n    log1.append(TestEvent())\n    log1.close()\n\n    first_archive = _archives(log_dir)[-1]\n\n    log2 = DiskEventLog(log_dir)\n    log2.append(TestEvent())\n    log2.append(TestEvent())\n    log2.close()\n\n    # Session 1 archive shifted to slot 2, session 2 in slot 1\n    second_archive = _archives(log_dir)[-1]\n    should_be_first_archive = _archives(log_dir)[-2]\n\n    assert first_archive.exists()\n    assert 
second_archive.exists()\n    assert first_archive != second_archive\n    assert should_be_first_archive == first_archive\n\n\ndef test_rotation_keeps_at_most_5_archives(log_dir: Path):\n    \"\"\"After 7 sessions, only the 5 most recent archives should remain.\"\"\"\n    all_archives: list[Path] = []\n    for _ in range(7):\n        log = DiskEventLog(log_dir)\n        log.append(TestEvent())\n        log.close()\n        all_archives.append(_archives(log_dir)[-1])\n\n    for old in all_archives[:2]:\n        assert not old.exists()\n    for recent in all_archives[2:]:\n        assert recent.exists()\n"
  },
  {
    "path": "src/exo/utils/tests/test_mp_channel.py",
    "content": "import multiprocessing as mp\nimport time\n\nimport pytest\nfrom anyio import fail_after\nfrom loguru import logger\n\nfrom exo.utils.channels import MpReceiver, MpSender, mp_channel\n\n\ndef foo(recv: MpReceiver[str]):\n    expected = [\"hi\", \"hi 2\", \"bye\"]\n    with recv as r:\n        for item in r:\n            assert item == expected.pop(0)\n\n\ndef bar(send: MpSender[str]):\n    logger.warning(\"hi\")\n    send.send(\"hi\")\n    time.sleep(0.1)\n    logger.warning(\"hi 2\")\n    send.send(\"hi 2\")\n    time.sleep(0.1)\n    logger.warning(\"bye\")\n    send.send(\"bye\")\n    time.sleep(0.1)\n    send.close()\n\n\n@pytest.mark.anyio\nasync def test_channel_ipc():\n    with fail_after(0.5):\n        s, r = mp_channel[str]()\n        p1 = mp.Process(target=foo, args=(r,))\n        p2 = mp.Process(target=bar, args=(s,))\n        p1.start()\n        p2.start()\n        p1.join()\n        p2.join()\n"
  },
  {
    "path": "src/exo/utils/tests/test_power_sampler.py",
    "content": "from collections.abc import Mapping\n\nimport anyio\nimport pytest\n\nfrom exo.api.types import PowerUsage\nfrom exo.shared.types.common import NodeId\nfrom exo.shared.types.profiling import SystemPerformanceProfile\nfrom exo.utils.power_sampler import PowerSampler\n\n\ndef _make_profile(sys_power: float) -> SystemPerformanceProfile:\n    return SystemPerformanceProfile(sys_power=sys_power)\n\n\nNODE_A = NodeId(\"node-a\")\nNODE_B = NodeId(\"node-b\")\n\n\n@pytest.fixture\ndef single_node_sampler() -> PowerSampler:\n    state: dict[NodeId, SystemPerformanceProfile] = {\n        NODE_A: _make_profile(10.0),\n    }\n    return PowerSampler(get_node_system=lambda: state)\n\n\n@pytest.fixture\ndef multi_node_state() -> dict[NodeId, SystemPerformanceProfile]:\n    return {\n        NODE_A: _make_profile(10.0),\n        NODE_B: _make_profile(20.0),\n    }\n\n\nasync def test_single_sample(single_node_sampler: PowerSampler) -> None:\n    \"\"\"A sampler that runs briefly should capture at least the initial sample.\"\"\"\n    async with anyio.create_task_group() as tg:\n        tg.start_soon(single_node_sampler.run)\n        await anyio.sleep(0.05)\n        tg.cancel_scope.cancel()\n\n    result = single_node_sampler.result()\n    assert len(result.nodes) == 1\n    assert result.nodes[0].node_id == NODE_A\n    assert result.nodes[0].avg_sys_power == 10.0\n    assert result.nodes[0].samples >= 1\n    assert result.elapsed_seconds > 0\n\n\nasync def test_multi_node_averaging(\n    multi_node_state: dict[NodeId, SystemPerformanceProfile],\n) -> None:\n    \"\"\"Power from multiple nodes should be summed for total cluster power.\"\"\"\n    sampler = PowerSampler(get_node_system=lambda: multi_node_state)\n    async with anyio.create_task_group() as tg:\n        tg.start_soon(sampler.run)\n        await anyio.sleep(0.05)\n        tg.cancel_scope.cancel()\n\n    result = sampler.result()\n    assert len(result.nodes) == 2\n    assert result.total_avg_sys_power_watts == 30.0\n\n\nasync def test_energy_calculation(single_node_sampler: PowerSampler) -> None:\n    \"\"\"Energy (joules) should be avg_power * elapsed_seconds.\"\"\"\n    async with anyio.create_task_group() as tg:\n        tg.start_soon(single_node_sampler.run)\n        await anyio.sleep(0.3)\n        tg.cancel_scope.cancel()\n\n    result = single_node_sampler.result()\n    expected_energy = result.total_avg_sys_power_watts * result.elapsed_seconds\n    assert result.total_energy_joules == expected_energy\n\n\nasync def test_changing_power_is_averaged() -> None:\n    \"\"\"When power changes mid-sampling, the result should be the average.\"\"\"\n    state: dict[NodeId, SystemPerformanceProfile] = {\n        NODE_A: _make_profile(10.0),\n    }\n    sampler = PowerSampler(get_node_system=lambda: state, interval=0.05)\n\n    async with anyio.create_task_group() as tg:\n        tg.start_soon(sampler.run)\n        await anyio.sleep(0.15)\n        state[NODE_A] = _make_profile(20.0)\n        await anyio.sleep(0.15)\n        tg.cancel_scope.cancel()\n\n    result = sampler.result()\n    avg = result.nodes[0].avg_sys_power\n    # Should be between 10 and 20, not exactly either\n    assert 10.0 < avg < 20.0\n\n\nasync def test_empty_state() -> None:\n    \"\"\"A sampler with no nodes should return an empty result.\"\"\"\n    empty: Mapping[NodeId, SystemPerformanceProfile] = {}\n    sampler = PowerSampler(get_node_system=lambda: empty)\n\n    async with anyio.create_task_group() as tg:\n        tg.start_soon(sampler.run)\n        await 
anyio.sleep(0.05)\n        tg.cancel_scope.cancel()\n\n    result = sampler.result()\n    assert len(result.nodes) == 0\n    assert result.total_avg_sys_power_watts == 0.0\n    assert result.total_energy_joules == 0.0\n\n\nasync def test_result_stops_sampling() -> None:\n    \"\"\"Calling result() should stop the sampler's run loop.\"\"\"\n    state: dict[NodeId, SystemPerformanceProfile] = {\n        NODE_A: _make_profile(10.0),\n    }\n    sampler = PowerSampler(get_node_system=lambda: state, interval=0.02)\n\n    result: PowerUsage | None = None\n    async with anyio.create_task_group() as tg:\n        tg.start_soon(sampler.run)\n        await anyio.sleep(0.1)\n        result = sampler.result()\n        # run() should exit on its own since _stopped is True\n        await anyio.sleep(0.1)\n        tg.cancel_scope.cancel()\n\n    assert result is not None\n    assert result.nodes[0].samples >= 2\n"
  },
  {
    "path": "src/exo/utils/tests/test_tagged.py",
    "content": "import anyio\nimport pytest\nfrom pydantic import BaseModel, TypeAdapter, ValidationError\n\nfrom exo.utils.pydantic_ext import TaggedModel\n\n\ndef test_plain_union_prefers_first_member_when_shapes_are_identical():\n    class Foo1(BaseModel):\n        x: int\n\n    class Foo2(BaseModel):\n        x: int\n\n    # Base Pydantic behavior: ambiguous dict goes to the first union member\n    ta = TypeAdapter[Foo1 | Foo2](Foo1 | Foo2)\n    out = ta.validate_python({\"x\": 1})\n    assert isinstance(out, Foo1), (\n        \"Base Pydantic should pick the first union member for identical shapes\"\n    )\n\n\ndef test_tagged_union_serializes_and_deserializes_two_identical_shapes_correctly():\n    class Foo1(TaggedModel):\n        x: int\n\n    class Foo2(TaggedModel):\n        x: int\n\n    t1 = Foo1(x=1)\n    assert t1.model_dump() == {\"Foo1\": {\"x\": 1}}\n\n    t2 = Foo2(x=2)\n    assert t2.model_dump() == {\"Foo2\": {\"x\": 2}}\n\n    # ---- deserialize (TypeAdapter -> model_validator(before)) ----\n    ta = TypeAdapter[Foo1 | Foo2](Foo1 | Foo2)\n\n    out1 = ta.validate_python({\"Foo1\": {\"x\": 10}})\n    assert isinstance(out1, Foo1) and out1.x == 10\n\n    out2 = ta.validate_python({\"Foo2\": {\"x\": 20}})\n    assert isinstance(out2, Foo2) and out2.x == 20\n\n\ndef test_tagged_union_rejects_unknown_tag():\n    class Foo1(TaggedModel):\n        x: int\n\n    class Foo2(TaggedModel):\n        x: int\n\n    ta = TypeAdapter[Foo1 | Foo2](Foo1 | Foo2)\n    with pytest.raises(ValidationError):\n        ta.validate_python({\"NotARealTag\": {\"x\": 0}})\n\n\ndef test_two_tagged_classes_with_different_shapes_are_independent_and_not_cross_deserializable():\n    class A1(TaggedModel):\n        x: int\n\n    class A2(TaggedModel):\n        name: str\n\n    class B1(TaggedModel):\n        name: str\n\n    class B2(TaggedModel):\n        active: bool\n\n    a_payload = A1(x=123).model_dump()\n    b_payload = B1(name=\"neo\").model_dump()\n\n    assert a_payload == {\"A1\": {\"x\": 123}}\n    assert b_payload == {\"B1\": {\"name\": \"neo\"}}\n\n    ta_a = TypeAdapter[A1 | A2](A1 | A2)\n    ta_b = TypeAdapter[B1 | B2](B1 | B2)\n\n    with pytest.raises(ValidationError):\n        ta_a.validate_python(b_payload)\n\n    with pytest.raises(ValidationError):\n        ta_b.validate_python(a_payload)\n\n\nclass Inner(TaggedModel):\n    x: int\n\n\nclass Outer(TaggedModel):\n    inner: Inner\n\n\nclass Wrapper(TaggedModel):\n    outer: Outer\n    label: str\n\n\nclass Container(TaggedModel):\n    items: list[Inner]\n    nested: Wrapper\n\n\ndef test_single_level_tagging():\n    inner = Inner(x=10)\n    dumped = inner.model_dump()\n    assert dumped == {\"Inner\": {\"x\": 10}}\n\n    restored = Inner.model_validate(dumped)\n    assert isinstance(restored, Inner)\n    assert restored.x == 10\n\n\ndef test_nested_externally_tagged_union_serializes_recursively():\n    outer = Outer(inner=Inner(x=42))\n    dumped = outer.model_dump()\n\n    assert dumped == {\"Outer\": {\"inner\": {\"Inner\": {\"x\": 42}}}}\n\n    restored = Outer.model_validate(dumped)\n    assert isinstance(restored.inner, Inner)\n    assert restored.inner.x == 42\n\n\ndef test_two_level_nested_tagging():\n    outer = Outer(inner=Inner(x=123))\n    dumped = outer.model_dump()\n    assert dumped == {\"Outer\": {\"inner\": {\"Inner\": {\"x\": 123}}}}\n\n    restored = Outer.model_validate(dumped)\n    assert isinstance(restored.inner, Inner)\n    assert restored.inner.x == 123\n\n\ndef test_three_level_nested_tagging():\n    wrapper = 
Wrapper(label=\"deep\", outer=Outer(inner=Inner(x=7)))\n    dumped = wrapper.model_dump()\n    # 3-level structure, each with exactly one tag\n    assert dumped == {\n        \"Wrapper\": {\n            \"label\": \"deep\",\n            \"outer\": {\"Outer\": {\"inner\": {\"Inner\": {\"x\": 7}}}},\n        }\n    }\n\n    restored = Wrapper.model_validate(dumped)\n    assert isinstance(restored.outer.inner, Inner)\n    assert restored.outer.inner.x == 7\n    assert restored.label == \"deep\"\n\n\ndef test_lists_and_mixed_nested_structures():\n    container = Container(\n        items=[Inner(x=1), Inner(x=2)],\n        nested=Wrapper(label=\"mix\", outer=Outer(inner=Inner(x=9))),\n    )\n    dumped = container.model_dump()\n\n    assert dumped == {\n        \"Container\": {\n            \"items\": [\n                {\"Inner\": {\"x\": 1}},\n                {\"Inner\": {\"x\": 2}},\n            ],\n            \"nested\": {\n                \"Wrapper\": {\n                    \"label\": \"mix\",\n                    \"outer\": {\"Outer\": {\"inner\": {\"Inner\": {\"x\": 9}}}},\n                }\n            },\n        }\n    }\n\n    restored = Container.model_validate(dumped)\n    assert isinstance(restored.nested.outer.inner, Inner)\n    assert [i.x for i in restored.items] == [1, 2]\n\n\ndef test_no_double_tagging_on_repeated_calls():\n    \"\"\"Ensure multiple model_dump calls don't stack tags.\"\"\"\n    inner = Inner(x=11)\n    dumped1 = inner.model_dump()\n    dumped2 = inner.model_dump()\n    assert dumped1 == dumped2 == {\"Inner\": {\"x\": 11}}\n\n    outer = Outer(inner=inner)\n    d1 = outer.model_dump()\n    d2 = outer.model_dump()\n    assert d1 == d2 == {\"Outer\": {\"inner\": {\"Inner\": {\"x\": 11}}}}\n\n\nclass L3A(TaggedModel):\n    x: int\n\n\nclass L3B(TaggedModel):\n    x: int\n\n\nclass L3C(TaggedModel):\n    x: int\n\n\nL3 = L3A | L3B | L3C\n\n\nclass L2A(TaggedModel):\n    child: L3\n\n\nclass L2B(TaggedModel):\n    child: L3\n\n\nclass L2C(TaggedModel):\n    child: L3\n\n\nL2 = L2A | L2B | L2C\n\n\nclass L1A(TaggedModel):\n    child: L2\n\n\nclass L1B(TaggedModel):\n    child: L2\n\n\nclass L1C(TaggedModel):\n    child: L2\n\n\nL1 = L1A | L1B | L1C\n\n\n@pytest.mark.anyio\nasync def test_tagged_union_is_fast():\n    # payload along the \"C\" path (worst case for DFS if branches are tried A->B->C)\n    payload = {\"L1C\": {\"child\": {\"L2C\": {\"child\": {\"L3C\": {\"x\": 123}}}}}}\n\n    with anyio.fail_after(0.1):\n        out = TypeAdapter(L1).validate_python(payload)  # type: ignore\n\n    # Sanity check the result\n    assert out.__class__.__name__ == \"L1C\"  # type: ignore\n    assert out.child.__class__.__name__ == \"L2C\"  # type: ignore\n    assert out.child.child.__class__.__name__ == \"L3C\"  # type: ignore\n    assert out.child.child.x == 123  # type: ignore\n"
  },
  {
    "path": "src/exo/worker/__init__.py",
    "content": ""
  },
  {
    "path": "src/exo/worker/engines/__init__.py",
    "content": ""
  },
  {
    "path": "src/exo/worker/engines/image/__init__.py",
    "content": "from exo.worker.engines.image.distributed_model import (\n    DistributedImageModel,\n    initialize_image_model,\n)\nfrom exo.worker.engines.image.generate import generate_image, warmup_image_generator\n\n__all__ = [\n    \"DistributedImageModel\",\n    \"generate_image\",\n    \"initialize_image_model\",\n    \"warmup_image_generator\",\n]\n"
  },
  {
    "path": "src/exo/worker/engines/image/config.py",
    "content": "from enum import Enum\n\nfrom pydantic import BaseModel\n\n\nclass BlockType(Enum):\n    JOINT = \"joint\"  # Separate image/text streams\n    SINGLE = \"single\"  # Concatenated streams\n\n\nclass TransformerBlockConfig(BaseModel):\n    model_config = {\"frozen\": True}\n\n    block_type: BlockType\n    count: int\n    has_separate_text_output: bool  # True for joint blocks that output text separately\n\n\nclass ImageModelConfig(BaseModel):\n    model_family: str\n\n    block_configs: tuple[TransformerBlockConfig, ...]\n\n    default_steps: dict[str, int]  # {\"low\": X, \"medium\": Y, \"high\": Z}\n    num_sync_steps: int  # Number of sync steps for distributed inference\n\n    guidance_scale: float | None = None  # None or <= 1.0 disables CFG\n\n    @property\n    def total_blocks(self) -> int:\n        return sum(bc.count for bc in self.block_configs)\n\n    @property\n    def joint_block_count(self) -> int:\n        return sum(\n            bc.count for bc in self.block_configs if bc.block_type == BlockType.JOINT\n        )\n\n    @property\n    def single_block_count(self) -> int:\n        return sum(\n            bc.count for bc in self.block_configs if bc.block_type == BlockType.SINGLE\n        )\n\n    def get_steps_for_quality(self, quality: str) -> int:\n        return self.default_steps[quality]\n"
  },
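  {
    "path": "examples/image_config_usage.py",
    "content": "# Hypothetical config sketch, added for illustration; not part of the original\n# source tree. The block counts and step numbers are invented and do not\n# describe any real model; they only show how the properties roll up.\nfrom exo.worker.engines.image.config import (\n    BlockType,\n    ImageModelConfig,\n    TransformerBlockConfig,\n)\n\ndemo = ImageModelConfig(\n    model_family=\"demo\",\n    block_configs=(\n        TransformerBlockConfig(\n            block_type=BlockType.JOINT, count=4, has_separate_text_output=True\n        ),\n        TransformerBlockConfig(\n            block_type=BlockType.SINGLE, count=8, has_separate_text_output=False\n        ),\n    ),\n    default_steps={\"low\": 4, \"medium\": 8, \"high\": 16},\n    num_sync_steps=2,\n)\n\nassert demo.total_blocks == 12\nassert demo.joint_block_count == 4\nassert demo.single_block_count == 8\nassert demo.get_steps_for_quality(\"medium\") == 8\n"
  },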
  {
    "path": "src/exo/worker/engines/image/distributed_model.py",
    "content": "from collections.abc import Generator\nfrom pathlib import Path\nfrom typing import Any, Literal, Optional\n\nimport mlx.core as mx\nfrom mflux.models.common.config.config import Config\nfrom PIL import Image\n\nfrom exo.api.types import AdvancedImageParams\nfrom exo.download.download_utils import build_model_path\nfrom exo.shared.types.worker.instances import BoundInstance\nfrom exo.shared.types.worker.shards import CfgShardMetadata, PipelineShardMetadata\nfrom exo.worker.engines.image.config import ImageModelConfig\nfrom exo.worker.engines.image.models import (\n    create_adapter_for_model,\n    get_config_for_model,\n)\nfrom exo.worker.engines.image.models.base import ModelAdapter\nfrom exo.worker.engines.image.pipeline import DiffusionRunner\nfrom exo.worker.engines.mlx.utils_mlx import mlx_distributed_init, mx_barrier\nfrom exo.worker.runner.bootstrap import logger\n\n\nclass DistributedImageModel:\n    _config: ImageModelConfig\n    _adapter: ModelAdapter[Any, Any]\n    _runner: DiffusionRunner\n\n    def __init__(\n        self,\n        model_id: str,\n        local_path: Path,\n        shard_metadata: PipelineShardMetadata | CfgShardMetadata,\n        group: Optional[mx.distributed.Group] = None,\n        quantize: int | None = None,\n    ):\n        config = get_config_for_model(model_id)\n        adapter = create_adapter_for_model(config, model_id, local_path, quantize)\n\n        has_layer_sharding = (\n            shard_metadata.start_layer != 0\n            or shard_metadata.end_layer != shard_metadata.n_layers\n        )\n\n        if group is not None and has_layer_sharding:\n            adapter.slice_transformer_blocks(\n                start_layer=shard_metadata.start_layer,\n                end_layer=shard_metadata.end_layer,\n            )\n\n        runner = DiffusionRunner(\n            config=config,\n            adapter=adapter,\n            group=group,\n            shard_metadata=shard_metadata,\n        )\n\n        if group is not None:\n            logger.info(\"Initialized distributed diffusion runner\")\n\n            mx.eval(adapter.model.parameters())  # pyright: ignore[reportAny]\n\n            # TODO(ciaran): Do we need this?\n            mx.eval(adapter.model)  # pyright: ignore[reportAny]\n\n            mx_barrier(group)\n            logger.info(f\"Transformer sharded for rank {group.rank()}\")\n        else:\n            logger.info(\"Single-node initialization\")\n\n        self._config = config\n        self._adapter = adapter\n        self._runner = runner\n\n    @classmethod\n    def from_bound_instance(\n        cls, bound_instance: BoundInstance\n    ) -> \"DistributedImageModel\":\n        model_id = bound_instance.bound_shard.model_card.model_id\n        model_path = build_model_path(model_id)\n\n        shard_metadata = bound_instance.bound_shard\n        if not isinstance(shard_metadata, (PipelineShardMetadata, CfgShardMetadata)):\n            raise ValueError(\n                \"Expected PipelineShardMetadata or CfgShardMetadata for image generation\"\n            )\n\n        is_distributed = (\n            len(bound_instance.instance.shard_assignments.node_to_runner) > 1\n        )\n\n        if is_distributed:\n            logger.info(\"Starting distributed init for image model\")\n            group = mlx_distributed_init(bound_instance)\n        else:\n            group = None\n\n        return cls(\n            model_id=model_id,\n            local_path=model_path,\n            shard_metadata=shard_metadata,\n           
 group=group,\n        )\n\n    def get_steps_for_quality(self, quality: Literal[\"low\", \"medium\", \"high\"]) -> int:\n        \"\"\"Get the number of inference steps for a quality level.\"\"\"\n        return self._config.get_steps_for_quality(quality)\n\n    def generate(\n        self,\n        prompt: str,\n        height: int,\n        width: int,\n        quality: Literal[\"low\", \"medium\", \"high\"] = \"medium\",\n        seed: int = 2,\n        image_path: Path | None = None,\n        partial_images: int = 0,\n        advanced_params: AdvancedImageParams | None = None,\n    ) -> Generator[Image.Image | tuple[Image.Image, int, int], None, None]:\n        if (\n            advanced_params is not None\n            and advanced_params.num_inference_steps is not None\n        ):\n            steps = advanced_params.num_inference_steps\n        else:\n            steps = self._config.get_steps_for_quality(quality)\n\n        guidance_override: float | None = None\n        if advanced_params is not None and advanced_params.guidance is not None:\n            guidance_override = advanced_params.guidance\n\n        negative_prompt: str | None = None\n        if advanced_params is not None and advanced_params.negative_prompt is not None:\n            negative_prompt = advanced_params.negative_prompt\n\n        # For edit mode: compute dimensions from input image\n        # This also stores image_paths in the adapter for encode_prompt()\n        if image_path is not None:\n            computed_dims = self._adapter.set_image_dimensions(image_path)\n            if computed_dims is not None:\n                # Override user-provided dimensions with computed ones\n                width, height = computed_dims\n\n        config = Config(\n            num_inference_steps=steps,\n            height=height,\n            width=width,\n            image_path=image_path,\n            model_config=self._adapter.model.model_config,  # pyright: ignore[reportAny]\n            guidance=guidance_override if guidance_override is not None else 4.0,\n        )\n\n        if advanced_params is not None and advanced_params.num_sync_steps is not None:\n            num_sync_steps = advanced_params.num_sync_steps\n        else:\n            num_sync_steps = self._config.num_sync_steps\n\n        for result in self._runner.generate_image(\n            runtime_config=config,\n            prompt=prompt,\n            seed=seed,\n            partial_images=partial_images,\n            guidance_override=guidance_override,\n            negative_prompt=negative_prompt,\n            num_sync_steps=num_sync_steps,\n        ):\n            if isinstance(result, tuple):\n                # Partial image: (GeneratedImage, partial_index, total_partials)\n                image, partial_idx, total_partials = result\n                yield (image, partial_idx, total_partials)\n            else:\n                logger.info(\"generated image\")\n                yield result\n\n\ndef initialize_image_model(bound_instance: BoundInstance) -> DistributedImageModel:\n    return DistributedImageModel.from_bound_instance(bound_instance)\n"
  },
  {
    "path": "src/exo/worker/engines/image/generate.py",
    "content": "import base64\nimport io\nimport random\nimport tempfile\nimport time\nfrom pathlib import Path\nfrom typing import Generator, Literal\n\nimport mlx.core as mx\nfrom PIL import Image\n\nfrom exo.api.types import (\n    AdvancedImageParams,\n    ImageEditsTaskParams,\n    ImageGenerationStats,\n    ImageGenerationTaskParams,\n    ImageSize,\n)\nfrom exo.shared.types.memory import Memory\nfrom exo.shared.types.worker.runner_response import (\n    ImageGenerationResponse,\n    PartialImageResponse,\n)\nfrom exo.worker.engines.image.distributed_model import DistributedImageModel\n\n\ndef parse_size(size_str: ImageSize) -> tuple[int, int]:\n    \"\"\"Parse size parameter like '1024x1024' to (width, height) tuple.\"\"\"\n    if size_str == \"auto\":\n        return (1024, 1024)\n\n    try:\n        parts = size_str.split(\"x\")\n        if len(parts) == 2:\n            width, height = int(parts[0]), int(parts[1])\n            if width > 0 and height > 0:\n                return (width, height)\n    except (ValueError, AttributeError):\n        pass\n\n    raise ValueError(\n        f\"Invalid size format: '{size_str}'. Expected 'WIDTHxHEIGHT' (e.g., '1024x1024')\"\n    )\n\n\ndef warmup_image_generator(model: DistributedImageModel) -> Image.Image | None:\n    \"\"\"Warmup the image generator with a small image.\"\"\"\n    with tempfile.TemporaryDirectory() as tmpdir:\n        # Create a small dummy image for warmup (needed for edit models)\n        dummy_image = Image.new(\"RGB\", (256, 256), color=(128, 128, 128))\n        dummy_path = Path(tmpdir) / \"warmup.png\"\n        dummy_image.save(dummy_path)\n\n        warmup_params = AdvancedImageParams(num_inference_steps=2)\n\n        for result in model.generate(\n            prompt=\"Warmup\",\n            height=256,\n            width=256,\n            quality=\"low\",\n            image_path=dummy_path,\n            advanced_params=warmup_params,\n        ):\n            if not isinstance(result, tuple):\n                return result\n    return None\n\n\ndef generate_image(\n    model: DistributedImageModel,\n    task: ImageGenerationTaskParams | ImageEditsTaskParams,\n) -> Generator[ImageGenerationResponse | PartialImageResponse, None, None]:\n    \"\"\"Generate image(s), optionally yielding partial results.\n\n    When partial_images > 0 or stream=True, yields PartialImageResponse for\n    intermediate images, then ImageGenerationResponse for the final image.\n\n    Yields:\n        PartialImageResponse for intermediate images (if partial_images > 0, first image only)\n        ImageGenerationResponse for final complete images\n    \"\"\"\n    width, height = parse_size(task.size)\n    quality: Literal[\"low\", \"medium\", \"high\"] = task.quality or \"medium\"\n\n    advanced_params = task.advanced_params\n    if advanced_params is not None and advanced_params.seed is not None:\n        base_seed = advanced_params.seed\n    else:\n        base_seed = random.randint(0, 2**32 - 1)\n\n    is_bench = getattr(task, \"bench\", False)\n    num_images = task.n or 1\n\n    generation_start_time: float = 0.0\n\n    if is_bench:\n        mx.reset_peak_memory()\n        generation_start_time = time.perf_counter()\n\n    partial_images = (\n        task.partial_images\n        if task.partial_images is not None and task.stream is not None and task.stream\n        else 0\n    )\n\n    image_path: Path | None = None\n\n    with tempfile.TemporaryDirectory() as tmpdir:\n        if isinstance(task, ImageEditsTaskParams):\n            # 
Decode base64 image data and save to temp file\n            image_path = Path(tmpdir) / \"input.png\"\n            image_path.write_bytes(base64.b64decode(task.image_data))\n            if task.size == \"auto\":\n                with Image.open(image_path) as img:\n                    width, height = img.size\n\n        for image_num in range(num_images):\n            # Increment seed for each image to ensure unique results\n            current_seed = base_seed + image_num\n\n            for result in model.generate(\n                prompt=task.prompt,\n                height=height,\n                width=width,\n                quality=quality,\n                seed=current_seed,\n                image_path=image_path,\n                partial_images=partial_images,\n                advanced_params=advanced_params,\n            ):\n                if isinstance(result, tuple):\n                    # Partial image: (Image, partial_index, total_partials)\n                    image, partial_idx, total_partials = result\n                    buffer = io.BytesIO()\n                    image_format = task.output_format.upper()\n                    if image_format == \"JPG\":\n                        image_format = \"JPEG\"\n                    if image_format == \"JPEG\" and image.mode == \"RGBA\":\n                        image = image.convert(\"RGB\")\n                    image.save(buffer, format=image_format)\n\n                    yield PartialImageResponse(\n                        image_data=buffer.getvalue(),\n                        format=task.output_format,\n                        partial_index=partial_idx,\n                        total_partials=total_partials,\n                        image_index=image_num,\n                    )\n                else:\n                    image = result\n\n                    # Only include stats on the final image\n                    stats: ImageGenerationStats | None = None\n                    if is_bench and image_num == num_images - 1:\n                        generation_end_time = time.perf_counter()\n                        total_generation_time = (\n                            generation_end_time - generation_start_time\n                        )\n\n                        num_inference_steps = model.get_steps_for_quality(quality)\n                        total_steps = num_inference_steps * num_images\n\n                        seconds_per_step = (\n                            total_generation_time / total_steps\n                            if total_steps > 0\n                            else 0.0\n                        )\n\n                        peak_memory = Memory.from_bytes(mx.get_peak_memory())\n\n                        stats = ImageGenerationStats(\n                            seconds_per_step=seconds_per_step,\n                            total_generation_time=total_generation_time,\n                            num_inference_steps=num_inference_steps,\n                            num_images=num_images,\n                            image_width=width,\n                            image_height=height,\n                            peak_memory_usage=peak_memory,\n                        )\n\n                    buffer = io.BytesIO()\n                    image_format = task.output_format.upper()\n                    if image_format == \"JPG\":\n                        image_format = \"JPEG\"\n                    if image_format == \"JPEG\" and image.mode == \"RGBA\":\n                        image = image.convert(\"RGB\")\n         
           image.save(buffer, format=image_format)\n\n                    yield ImageGenerationResponse(\n                        image_data=buffer.getvalue(),\n                        format=task.output_format,\n                        stats=stats,\n                        image_index=image_num,\n                    )\n"
  },
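  {
    "path": "docs/sketches/stream_consumer_sketch.py",
    "content": "\"\"\"Hedged sketch (illustrative, not shipped): consuming the generation stream.\n\nThe engine above yields PartialImageResponse previews interleaved with one\nfinal ImageGenerationResponse per image_index. This consumer is a guess at\ntypical usage; the response types are duck-typed here because their import\npath is not shown in this excerpt.\n\"\"\"\n\nfrom collections.abc import Iterable\nfrom typing import Any\n\n\ndef collect_final_images(responses: Iterable[Any]) -> list[bytes]:\n    \"\"\"Keep final image bytes, skipping streamed partial previews.\n\n    Partial responses are identified by the total_partials field they carry\n    in the engine code above; finals carry stats and image_data.\n    \"\"\"\n    final_images: list[bytes] = []\n    for response in responses:\n        if hasattr(response, \"total_partials\"):\n            # Partial preview: a real client would forward this for\n            # progressive rendering; this sketch drops it.\n            continue\n        final_images.append(response.image_data)\n    return final_images\n"
  },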
  {
    "path": "src/exo/worker/engines/image/models/__init__.py",
    "content": "from pathlib import Path\nfrom typing import Any, Callable\n\nfrom exo.worker.engines.image.config import ImageModelConfig\nfrom exo.worker.engines.image.models.base import ModelAdapter\nfrom exo.worker.engines.image.models.flux import (\n    FLUX_DEV_CONFIG,\n    FLUX_KONTEXT_CONFIG,\n    FLUX_SCHNELL_CONFIG,\n    FluxKontextModelAdapter,\n    FluxModelAdapter,\n)\nfrom exo.worker.engines.image.models.qwen import (\n    QWEN_IMAGE_CONFIG,\n    QWEN_IMAGE_EDIT_CONFIG,\n    QwenEditModelAdapter,\n    QwenModelAdapter,\n)\n\n__all__: list[str] = []\n\n# Type alias for adapter factory functions\n# Factory takes (config, model_id, local_path, quantize) and returns a ModelAdapter\nAdapterFactory = Callable[\n    [ImageModelConfig, str, Path, int | None], ModelAdapter[Any, Any]\n]\n\n# Registry maps model_family string to adapter factory\n_ADAPTER_REGISTRY: dict[str, AdapterFactory] = {\n    \"flux\": FluxModelAdapter,\n    \"flux-kontext\": FluxKontextModelAdapter,\n    \"qwen-edit\": QwenEditModelAdapter,\n    \"qwen\": QwenModelAdapter,\n}\n\n# Config registry: maps model ID patterns to configs\n# Order matters: longer/more-specific patterns must come before shorter ones\n_CONFIG_REGISTRY: dict[str, ImageModelConfig] = {\n    \"flux.1-schnell\": FLUX_SCHNELL_CONFIG,\n    \"flux.1-kontext\": FLUX_KONTEXT_CONFIG,  # Must come before \"flux.1-dev\" for pattern matching\n    \"flux.1-krea-dev\": FLUX_DEV_CONFIG,  # Must come before \"flux.1-dev\" for pattern matching\n    \"flux.1-dev\": FLUX_DEV_CONFIG,\n    \"qwen-image-edit\": QWEN_IMAGE_EDIT_CONFIG,  # Must come before \"qwen-image\" for pattern matching\n    \"qwen-image\": QWEN_IMAGE_CONFIG,\n}\n\n\ndef get_config_for_model(model_id: str) -> ImageModelConfig:\n    \"\"\"Get configuration for a model ID.\n\n    Args:\n        model_id: The model identifier (e.g., \"black-forest-labs/FLUX.1-schnell\")\n\n    Returns:\n        The model configuration\n\n    Raises:\n        ValueError: If no configuration found for model ID\n    \"\"\"\n    model_id_lower = model_id.lower()\n\n    for pattern, config in _CONFIG_REGISTRY.items():\n        if pattern in model_id_lower:\n            return config\n\n    raise ValueError(f\"No configuration found for model: {model_id}\")\n\n\ndef create_adapter_for_model(\n    config: ImageModelConfig,\n    model_id: str,\n    local_path: Path,\n    quantize: int | None = None,\n) -> ModelAdapter[Any, Any]:\n    \"\"\"Create a model adapter for the given configuration.\n\n    Args:\n        config: The model configuration\n        model_id: The model identifier\n        local_path: Path to the model weights\n        quantize: Optional quantization bits\n\n    Returns:\n        A ModelAdapter instance\n\n    Raises:\n        ValueError: If no adapter found for model family\n    \"\"\"\n    factory = _ADAPTER_REGISTRY.get(config.model_family)\n    if factory is None:\n        raise ValueError(f\"No adapter found for model family: {config.model_family}\")\n    return factory(config, model_id, local_path, quantize)\n"
  },
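  {
    "path": "docs/sketches/registry_usage_sketch.py",
    "content": "\"\"\"Hedged sketch (illustrative, not shipped): the two-step registry lookup.\n\nMirrors get_config_for_model + create_adapter_for_model from\nmodels/__init__.py. The weights path is a placeholder; quantize=None loads\nfull precision.\n\"\"\"\n\nfrom pathlib import Path\n\nfrom exo.worker.engines.image.models import (\n    create_adapter_for_model,\n    get_config_for_model,\n)\n\n\ndef load_adapter_sketch() -> None:\n    model_id = \"black-forest-labs/FLUX.1-schnell\"\n    # Substring match against the lowercased id; raises ValueError if unknown.\n    config = get_config_for_model(model_id)\n    adapter = create_adapter_for_model(\n        config=config,\n        model_id=model_id,\n        local_path=Path(\"/placeholder/path/to/weights\"),\n        quantize=None,\n    )\n    print(config.model_family, type(adapter).__name__)\n"
  },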
  {
    "path": "src/exo/worker/engines/image/models/base.py",
    "content": "from abc import ABC, abstractmethod\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING, Any, Generic, TypeVar\n\nimport mlx.core as mx\nfrom mflux.models.common.config.config import Config\nfrom mflux.models.common.latent_creator.latent_creator import Img2Img, LatentCreator\nfrom mflux.utils.image_util import ImageUtil\nfrom PIL import Image\n\nfrom exo.worker.engines.image.config import ImageModelConfig\n\nif TYPE_CHECKING:\n    from exo.worker.engines.image.pipeline.block_wrapper import (\n        JointBlockWrapper,\n        SingleBlockWrapper,\n    )\n\nModelT = TypeVar(\"ModelT\")\nTransformerT = TypeVar(\"TransformerT\")\n\nRotaryEmbeddings = mx.array | tuple[mx.array, mx.array]\n\n\nclass PromptData(ABC):\n    @property\n    @abstractmethod\n    def prompt_embeds(self) -> mx.array: ...\n\n    @property\n    @abstractmethod\n    def pooled_prompt_embeds(self) -> mx.array: ...\n\n    @property\n    @abstractmethod\n    def negative_prompt_embeds(self) -> mx.array | None: ...\n\n    @property\n    @abstractmethod\n    def negative_pooled_prompt_embeds(self) -> mx.array | None: ...\n\n    @abstractmethod\n    def get_encoder_hidden_states_mask(\n        self, positive: bool = True\n    ) -> mx.array | None: ...\n\n    @property\n    @abstractmethod\n    def cond_image_grid(\n        self,\n    ) -> tuple[int, int, int] | list[tuple[int, int, int]] | None:\n        \"\"\"Conditioning image grid dimensions for edit mode.\n\n        Returns:\n            Grid dimensions (edit) or None (standard generation).\n        \"\"\"\n        ...\n\n    @property\n    @abstractmethod\n    def conditioning_latents(self) -> mx.array | None:\n        \"\"\"Conditioning latents for edit mode.\n\n        Returns:\n            Conditioning latents array for image editing, None for standard generation.\n        \"\"\"\n        ...\n\n    @property\n    @abstractmethod\n    def kontext_image_ids(self) -> mx.array | None:\n        \"\"\"Kontext-style position IDs for image conditioning.\n\n        For FLUX.1-Kontext models, returns position IDs with first_coord=1\n        to distinguish conditioning tokens from generation tokens (first_coord=0).\n\n        Returns:\n            Position IDs array [1, seq_len, 3] for Kontext, None for other models.\n        \"\"\"\n        ...\n\n    @abstractmethod\n    def get_batched_cfg_data(\n        self,\n    ) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:\n        \"\"\"Get embeddings for CFG with batch_size=2.\n\n        Combines positive and negative embeddings into batched tensors for\n        a single forward pass. Pads shorter sequences to max length. 
Attention\n        mask is used to mask padding.\n\n        Returns:\n            None if model doesn't support CFG, otherwise tuple of:\n            - batched_embeds: [2, max_seq, hidden] (positive then negative)\n            - batched_mask: [2, max_seq] attention mask\n            - batched_pooled: [2, hidden] pooled embeddings or None\n            - conditioning_latents: [2, latent_seq, latent_dim] or None\n            TODO(ciaran): type this\n        \"\"\"\n        ...\n\n    @abstractmethod\n    def get_cfg_branch_data(\n        self, positive: bool\n    ) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:\n        \"\"\"Get embeddings for a single CFG branch (positive or negative).\n\n        Used for sequential CFG and CFG parallel modes where we process\n        one branch at a time instead of batching.\n\n        Args:\n            positive: True for positive prompt, False for negative prompt\n\n        Returns:\n            Tuple of:\n            - embeds: [1, seq, hidden] prompt embeddings\n            - mask: [1, seq] attention mask or None\n            - pooled: [1, hidden] pooled embeddings or None\n            - conditioning_latents: [1, latent_seq, latent_dim] or None\n        \"\"\"\n        ...\n\n\nclass ModelAdapter(ABC, Generic[ModelT, TransformerT]):\n    _config: ImageModelConfig\n    _model: ModelT\n    _transformer: TransformerT\n\n    @property\n    def config(self) -> ImageModelConfig:\n        return self._config\n\n    @property\n    def model(self) -> ModelT:\n        return self._model\n\n    @property\n    def transformer(self) -> TransformerT:\n        return self._transformer\n\n    @property\n    @abstractmethod\n    def hidden_dim(self) -> int: ...\n\n    @property\n    @abstractmethod\n    def needs_cfg(self) -> bool:\n        \"\"\"Whether this model uses classifier-free guidance.\"\"\"\n        ...\n\n    @abstractmethod\n    def _get_latent_creator(self) -> type: ...\n\n    @abstractmethod\n    def get_joint_block_wrappers(\n        self,\n        text_seq_len: int,\n        encoder_hidden_states_mask: mx.array | None = None,\n    ) -> list[\"JointBlockWrapper[Any]\"]:\n        \"\"\"Create wrapped joint transformer blocks with pipefusion support.\n\n        Args:\n            text_seq_len: Number of text tokens (constant for generation)\n            encoder_hidden_states_mask: Attention mask for text (Qwen only)\n\n        Returns:\n            List of wrapped joint blocks ready for pipefusion\n        \"\"\"\n        ...\n\n    @abstractmethod\n    def get_single_block_wrappers(\n        self,\n        text_seq_len: int,\n    ) -> list[\"SingleBlockWrapper[Any]\"]:\n        \"\"\"Create wrapped single transformer blocks with pipefusion support.\n\n        Args:\n            text_seq_len: Number of text tokens (constant for generation)\n\n        Returns:\n            List of wrapped single blocks ready for pipefusion\n        \"\"\"\n        ...\n\n    @abstractmethod\n    def slice_transformer_blocks(\n        self,\n        start_layer: int,\n        end_layer: int,\n    ):\n        \"\"\"Remove transformer blocks outside the assigned range.\n\n        This should be called BEFORE mx.eval() to avoid loading unused weights\n        in distributed mode.\n\n        Args:\n            start_layer: First layer index (inclusive) assigned to this node\n            end_layer: Last layer index (exclusive) assigned to this node\n        \"\"\"\n        ...\n\n    def set_image_dimensions(self, image_path: Path) -> tuple[int, int] | 
None:\n        \"\"\"Default implementation: no dimension computation needed.\n\n        Override in edit adapters to compute dimensions from input image.\n        TODO(ciaran): this is a hack\n\n        Returns:\n            None (use user-specified dimensions)\n        \"\"\"\n        return None\n\n    def create_latents(self, seed: int, runtime_config: Config) -> mx.array:\n        \"\"\"Create initial latents. Uses model-specific latent creator.\"\"\"\n        model: Any = self.model\n        return LatentCreator.create_for_txt2img_or_img2img(\n            seed=seed,\n            height=runtime_config.height,\n            width=runtime_config.width,\n            img2img=Img2Img(\n                vae=model.vae,  # pyright: ignore[reportAny]\n                latent_creator=self._get_latent_creator(),\n                sigmas=runtime_config.scheduler.sigmas,  # pyright: ignore[reportAny]\n                init_time_step=runtime_config.init_time_step,\n                image_path=runtime_config.image_path,\n            ),\n        )\n\n    def decode_latents(\n        self,\n        latents: mx.array,\n        runtime_config: Config,\n        seed: int,\n        prompt: str,\n    ) -> Image.Image:\n        model: Any = self.model  # Allow attribute access on model\n        latents = self._get_latent_creator().unpack_latents(  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]\n            latents=latents,\n            height=runtime_config.height,\n            width=runtime_config.width,\n        )\n        decoded = model.vae.decode(latents)  # pyright: ignore[reportAny]\n        # TODO(ciaran):\n        # from mflux.models.common.vae.vae_util import VAEUtil\n        # VAEUtil.decode(vae=model.vae, latents=latents, tiling_config=self.tiling_config)\n        generated_image = ImageUtil.to_image(\n            decoded_latents=decoded,  # pyright: ignore[reportAny]\n            config=runtime_config,\n            seed=seed,\n            prompt=prompt,\n            quantization=model.bits,  # pyright: ignore[reportAny]\n            lora_paths=model.lora_paths,  # pyright: ignore[reportAny]\n            lora_scales=model.lora_scales,  # pyright: ignore[reportAny]\n            image_path=runtime_config.image_path,\n            image_strength=runtime_config.image_strength,\n            generation_time=0,\n        )\n        return generated_image.image\n\n    @abstractmethod\n    def encode_prompt(\n        self, prompt: str, negative_prompt: str | None = None\n    ) -> \"PromptData\": ...\n\n    @abstractmethod\n    def compute_embeddings(\n        self,\n        hidden_states: mx.array,\n        prompt_embeds: mx.array,\n    ) -> tuple[mx.array, mx.array]: ...\n\n    @abstractmethod\n    def compute_text_embeddings(\n        self,\n        t: int,\n        runtime_config: Config,\n        pooled_prompt_embeds: mx.array | None = None,\n        hidden_states: mx.array | None = None,\n    ) -> mx.array: ...\n\n    @abstractmethod\n    def compute_rotary_embeddings(\n        self,\n        prompt_embeds: mx.array,\n        runtime_config: Config,\n        encoder_hidden_states_mask: mx.array | None = None,\n        cond_image_grid: tuple[int, int, int]\n        | list[tuple[int, int, int]]\n        | None = None,\n        kontext_image_ids: mx.array | None = None,\n    ) -> RotaryEmbeddings: ...\n\n    def merge_streams(\n        self,\n        hidden_states: mx.array,\n        encoder_hidden_states: mx.array,\n    ) -> mx.array:\n        return 
mx.concatenate([encoder_hidden_states, hidden_states], axis=1)\n\n    @abstractmethod\n    def apply_guidance(\n        self,\n        noise_positive: mx.array,\n        noise_negative: mx.array,\n        guidance_scale: float,\n    ) -> mx.array:\n        \"\"\"Apply classifier-free guidance to combine positive/negative predictions.\n\n        Only called when needs_cfg is True.\n\n        Args:\n            noise_positive: Noise prediction from positive prompt\n            noise_negative: Noise prediction from negative prompt\n            guidance_scale: Guidance strength\n\n        Returns:\n            Guided noise prediction\n        \"\"\"\n        ...\n\n    def final_projection(\n        self,\n        hidden_states: mx.array,\n        text_embeddings: mx.array,\n    ) -> mx.array:\n        transformer: Any = self.transformer\n        hidden_states = transformer.norm_out(hidden_states, text_embeddings)  # pyright: ignore[reportAny]\n        return transformer.proj_out(hidden_states)  # pyright: ignore[reportAny]\n"
  },
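  {
    "path": "docs/sketches/cfg_selection_sketch.py",
    "content": "\"\"\"Hedged sketch (illustrative, not shipped): selecting CFG inputs from PromptData.\n\nShows how a denoise loop might choose between the two CFG paths that\nPromptData exposes: get_batched_cfg_data (one batch_size=2 forward pass)\nversus get_cfg_branch_data (one pass per branch). The needs_cfg flag comes\nfrom ModelAdapter; the surrounding loop is elided and assumed.\n\"\"\"\n\nimport mlx.core as mx\n\nfrom exo.worker.engines.image.models.base import PromptData\n\nCfgBranch = tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]\n\n\ndef select_cfg_inputs(prompt_data: PromptData, needs_cfg: bool) -> list[CfgBranch]:\n    if not needs_cfg:\n        # Models without CFG (e.g. Flux) still expose the positive branch.\n        return [prompt_data.get_cfg_branch_data(positive=True)]\n    batched = prompt_data.get_batched_cfg_data()\n    if batched is not None:\n        # Positive and negative stacked along the batch axis.\n        return [batched]\n    # Sequential fallback: one forward pass per branch.\n    return [\n        prompt_data.get_cfg_branch_data(positive=True),\n        prompt_data.get_cfg_branch_data(positive=False),\n    ]\n"
  },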
  {
    "path": "src/exo/worker/engines/image/models/flux/__init__.py",
    "content": "from exo.worker.engines.image.models.flux.adapter import FluxModelAdapter\nfrom exo.worker.engines.image.models.flux.config import (\n    FLUX_DEV_CONFIG,\n    FLUX_KONTEXT_CONFIG,\n    FLUX_SCHNELL_CONFIG,\n)\nfrom exo.worker.engines.image.models.flux.kontext_adapter import (\n    FluxKontextModelAdapter,\n)\n\n__all__ = [\n    \"FluxModelAdapter\",\n    \"FluxKontextModelAdapter\",\n    \"FLUX_DEV_CONFIG\",\n    \"FLUX_KONTEXT_CONFIG\",\n    \"FLUX_SCHNELL_CONFIG\",\n]\n"
  },
  {
    "path": "src/exo/worker/engines/image/models/flux/adapter.py",
    "content": "from pathlib import Path\nfrom typing import Any\n\nimport mlx.core as mx\nfrom mflux.models.common.config.config import Config\nfrom mflux.models.common.config.model_config import ModelConfig\nfrom mflux.models.flux.latent_creator.flux_latent_creator import FluxLatentCreator\nfrom mflux.models.flux.model.flux_text_encoder.prompt_encoder import PromptEncoder\nfrom mflux.models.flux.model.flux_transformer.transformer import Transformer\nfrom mflux.models.flux.variants.txt2img.flux import Flux1\n\nfrom exo.worker.engines.image.config import ImageModelConfig\nfrom exo.worker.engines.image.models.base import (\n    ModelAdapter,\n    PromptData,\n    RotaryEmbeddings,\n)\nfrom exo.worker.engines.image.models.flux.wrappers import (\n    FluxJointBlockWrapper,\n    FluxSingleBlockWrapper,\n)\nfrom exo.worker.engines.image.pipeline.block_wrapper import (\n    JointBlockWrapper,\n    SingleBlockWrapper,\n)\n\n\nclass FluxPromptData(PromptData):\n    def __init__(self, prompt_embeds: mx.array, pooled_prompt_embeds: mx.array):\n        self._prompt_embeds = prompt_embeds\n        self._pooled_prompt_embeds = pooled_prompt_embeds\n\n    @property\n    def prompt_embeds(self) -> mx.array:\n        return self._prompt_embeds\n\n    @property\n    def pooled_prompt_embeds(self) -> mx.array:\n        return self._pooled_prompt_embeds\n\n    @property\n    def negative_prompt_embeds(self) -> mx.array | None:\n        return None\n\n    @property\n    def negative_pooled_prompt_embeds(self) -> mx.array | None:\n        return None\n\n    def get_encoder_hidden_states_mask(self, positive: bool = True) -> mx.array | None:\n        return None\n\n    @property\n    def cond_image_grid(\n        self,\n    ) -> tuple[int, int, int] | list[tuple[int, int, int]] | None:\n        return None\n\n    @property\n    def conditioning_latents(self) -> mx.array | None:\n        return None\n\n    @property\n    def kontext_image_ids(self) -> mx.array | None:\n        return None\n\n    def get_batched_cfg_data(\n        self,\n    ) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:\n        return None\n\n    def get_cfg_branch_data(\n        self, positive: bool\n    ) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:\n        \"\"\"Flux doesn't use CFG, but we return positive data for compatibility.\"\"\"\n        return (self._prompt_embeds, None, self._pooled_prompt_embeds, None)\n\n\nclass FluxModelAdapter(ModelAdapter[Flux1, Transformer]):\n    def __init__(\n        self,\n        config: ImageModelConfig,\n        model_id: str,\n        local_path: Path,\n        quantize: int | None = None,\n    ):\n        self._config = config\n        self._model = Flux1(\n            model_config=ModelConfig.from_name(model_name=model_id, base_model=None),\n            model_path=str(local_path),\n            quantize=quantize,\n        )\n        self._transformer = self._model.transformer\n\n    @property\n    def hidden_dim(self) -> int:\n        return self._transformer.x_embedder.weight.shape[0]\n\n    @property\n    def needs_cfg(self) -> bool:\n        return False\n\n    def _get_latent_creator(self) -> type:\n        return FluxLatentCreator\n\n    def get_joint_block_wrappers(\n        self,\n        text_seq_len: int,\n        encoder_hidden_states_mask: mx.array | None = None,\n    ) -> list[JointBlockWrapper[Any]]:\n        \"\"\"Create wrapped joint blocks for Flux.\"\"\"\n        return [\n            FluxJointBlockWrapper(block, text_seq_len)\n      
      for block in self._transformer.transformer_blocks\n        ]\n\n    def get_single_block_wrappers(\n        self,\n        text_seq_len: int,\n    ) -> list[SingleBlockWrapper[Any]]:\n        \"\"\"Create wrapped single blocks for Flux.\"\"\"\n        return [\n            FluxSingleBlockWrapper(block, text_seq_len)\n            for block in self._transformer.single_transformer_blocks\n        ]\n\n    def slice_transformer_blocks(\n        self,\n        start_layer: int,\n        end_layer: int,\n    ):\n        all_joint = list(self._transformer.transformer_blocks)\n        all_single = list(self._transformer.single_transformer_blocks)\n        total_joint_blocks = len(all_joint)\n        if end_layer <= total_joint_blocks:\n            # All assigned are joint blocks\n            joint_start, joint_end = start_layer, end_layer\n            single_start, single_end = 0, 0\n        elif start_layer >= total_joint_blocks:\n            # All assigned are single blocks\n            joint_start, joint_end = 0, 0\n            single_start = start_layer - total_joint_blocks\n            single_end = end_layer - total_joint_blocks\n        else:\n            # Spans both joint and single\n            joint_start, joint_end = start_layer, total_joint_blocks\n            single_start = 0\n            single_end = end_layer - total_joint_blocks\n\n        self._transformer.transformer_blocks = all_joint[joint_start:joint_end]\n\n        self._transformer.single_transformer_blocks = all_single[\n            single_start:single_end\n        ]\n\n    def encode_prompt(\n        self, prompt: str, negative_prompt: str | None = None\n    ) -> FluxPromptData:\n        del negative_prompt\n\n        assert isinstance(self.model.prompt_cache, dict)\n        assert isinstance(self.model.tokenizers, dict)\n\n        prompt_embeds, pooled_prompt_embeds = PromptEncoder.encode_prompt(\n            prompt=prompt,\n            prompt_cache=self.model.prompt_cache,\n            t5_tokenizer=self.model.tokenizers[\"t5\"],  # pyright: ignore[reportAny]\n            clip_tokenizer=self.model.tokenizers[\"clip\"],  # pyright: ignore[reportAny]\n            t5_text_encoder=self.model.t5_text_encoder,\n            clip_text_encoder=self.model.clip_text_encoder,\n        )\n        return FluxPromptData(\n            prompt_embeds=prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n        )\n\n    def compute_embeddings(\n        self,\n        hidden_states: mx.array,\n        prompt_embeds: mx.array,\n    ) -> tuple[mx.array, mx.array]:\n        embedded_hidden = self._transformer.x_embedder(hidden_states)\n        embedded_encoder = self._transformer.context_embedder(prompt_embeds)\n        return embedded_hidden, embedded_encoder\n\n    def compute_text_embeddings(\n        self,\n        t: int,\n        runtime_config: Config,\n        pooled_prompt_embeds: mx.array | None = None,\n        hidden_states: mx.array | None = None,  # Ignored by Flux\n    ) -> mx.array:\n        if pooled_prompt_embeds is None:\n            raise ValueError(\n                \"pooled_prompt_embeds is required for Flux text embeddings\"\n            )\n\n        # hidden_states is ignored - Flux uses pooled_prompt_embeds instead\n        return Transformer.compute_text_embeddings(\n            t, pooled_prompt_embeds, self._transformer.time_text_embed, runtime_config\n        )\n\n    def compute_rotary_embeddings(\n        self,\n        prompt_embeds: mx.array,\n        runtime_config: Config,\n        
encoder_hidden_states_mask: mx.array | None = None,\n        cond_image_grid: tuple[int, int, int]\n        | list[tuple[int, int, int]]\n        | None = None,\n        kontext_image_ids: mx.array | None = None,\n    ) -> RotaryEmbeddings:\n        return Transformer.compute_rotary_embeddings(\n            prompt_embeds,\n            self._transformer.pos_embed,\n            runtime_config,\n            kontext_image_ids,\n        )\n\n    def apply_guidance(\n        self,\n        noise_positive: mx.array,\n        noise_negative: mx.array,\n        guidance_scale: float,\n    ) -> mx.array:\n        raise NotImplementedError(\"Flux does not use classifier-free guidance\")\n"
  },
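  {
    "path": "docs/sketches/block_slicing_sketch.py",
    "content": "\"\"\"Hedged sketch (illustrative, not shipped): the joint/single split arithmetic.\n\nExtracts the index logic of slice_transformer_blocks from the Flux adapters\nas a pure function, with a worked example for the 19-joint/38-single layout\nused by the FLUX configs.\n\"\"\"\n\n\ndef split_layer_range(\n    start_layer: int, end_layer: int, total_joint_blocks: int\n) -> tuple[tuple[int, int], tuple[int, int]]:\n    \"\"\"Return ((joint_start, joint_end), (single_start, single_end)).\"\"\"\n    if end_layer <= total_joint_blocks:\n        # All assigned layers are joint blocks.\n        return (start_layer, end_layer), (0, 0)\n    if start_layer >= total_joint_blocks:\n        # All assigned layers are single blocks.\n        return (0, 0), (\n            start_layer - total_joint_blocks,\n            end_layer - total_joint_blocks,\n        )\n    # The range spans the joint/single boundary.\n    return (start_layer, total_joint_blocks), (0, end_layer - total_joint_blocks)\n\n\n# A node assigned layers [10, 30) of a 19-joint/38-single model keeps joint\n# blocks 10..18 and single blocks 0..10.\nassert split_layer_range(10, 30, 19) == ((10, 19), (0, 11))\n"
  },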
  {
    "path": "src/exo/worker/engines/image/models/flux/config.py",
    "content": "from exo.worker.engines.image.config import (\n    BlockType,\n    ImageModelConfig,\n    TransformerBlockConfig,\n)\n\nFLUX_SCHNELL_CONFIG = ImageModelConfig(\n    model_family=\"flux\",\n    block_configs=(\n        TransformerBlockConfig(\n            block_type=BlockType.JOINT, count=19, has_separate_text_output=True\n        ),\n        TransformerBlockConfig(\n            block_type=BlockType.SINGLE, count=38, has_separate_text_output=False\n        ),\n    ),\n    default_steps={\"low\": 1, \"medium\": 2, \"high\": 4},\n    num_sync_steps=1,\n)\n\n\nFLUX_DEV_CONFIG = ImageModelConfig(\n    model_family=\"flux\",\n    block_configs=(\n        TransformerBlockConfig(\n            block_type=BlockType.JOINT, count=19, has_separate_text_output=True\n        ),\n        TransformerBlockConfig(\n            block_type=BlockType.SINGLE, count=38, has_separate_text_output=False\n        ),\n    ),\n    default_steps={\"low\": 10, \"medium\": 25, \"high\": 50},\n    num_sync_steps=4,\n)\n\n\nFLUX_KONTEXT_CONFIG = ImageModelConfig(\n    model_family=\"flux-kontext\",\n    block_configs=(\n        TransformerBlockConfig(\n            block_type=BlockType.JOINT, count=19, has_separate_text_output=True\n        ),\n        TransformerBlockConfig(\n            block_type=BlockType.SINGLE, count=38, has_separate_text_output=False\n        ),\n    ),\n    default_steps={\"low\": 10, \"medium\": 25, \"high\": 50},\n    num_sync_steps=4,\n    guidance_scale=4.0,\n)\n"
  },
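  {
    "path": "docs/sketches/quality_steps_sketch.py",
    "content": "\"\"\"Hedged sketch (illustrative, not shipped): reading step counts from a config.\n\ndefault_steps maps a quality preset to an inference step count; the engine\nreads it via model.get_steps_for_quality in the generation loop above.\nnum_sync_steps is presumably the number of warm-up steps run synchronously\nbefore pipelined execution begins.\n\"\"\"\n\nfrom exo.worker.engines.image.models.flux.config import FLUX_DEV_CONFIG\n\n\ndef steps_for_quality_sketch(quality: str) -> int:\n    return FLUX_DEV_CONFIG.default_steps[quality]\n\n\nassert steps_for_quality_sketch(\"medium\") == 25\nassert FLUX_DEV_CONFIG.num_sync_steps == 4\n"
  },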
  {
    "path": "src/exo/worker/engines/image/models/flux/kontext_adapter.py",
    "content": "import math\nfrom pathlib import Path\nfrom typing import Any, final\n\nimport mlx.core as mx\nfrom mflux.models.common.config.config import Config\nfrom mflux.models.common.config.model_config import ModelConfig\nfrom mflux.models.flux.latent_creator.flux_latent_creator import FluxLatentCreator\nfrom mflux.models.flux.model.flux_text_encoder.prompt_encoder import PromptEncoder\nfrom mflux.models.flux.model.flux_transformer.transformer import Transformer\nfrom mflux.models.flux.variants.kontext.flux_kontext import Flux1Kontext\nfrom mflux.models.flux.variants.kontext.kontext_util import KontextUtil\n\nfrom exo.worker.engines.image.config import ImageModelConfig\nfrom exo.worker.engines.image.models.base import (\n    ModelAdapter,\n    PromptData,\n    RotaryEmbeddings,\n)\nfrom exo.worker.engines.image.models.flux.wrappers import (\n    FluxJointBlockWrapper,\n    FluxSingleBlockWrapper,\n)\nfrom exo.worker.engines.image.pipeline.block_wrapper import (\n    JointBlockWrapper,\n    SingleBlockWrapper,\n)\n\n\n@final\nclass FluxKontextPromptData(PromptData):\n    \"\"\"Prompt data for FLUX.1-Kontext image editing.\n\n    Stores text embeddings along with conditioning latents and position IDs\n    for the input image.\n    \"\"\"\n\n    def __init__(\n        self,\n        prompt_embeds: mx.array,\n        pooled_prompt_embeds: mx.array,\n        conditioning_latents: mx.array,\n        kontext_image_ids: mx.array,\n    ):\n        self._prompt_embeds = prompt_embeds\n        self._pooled_prompt_embeds = pooled_prompt_embeds\n        self._conditioning_latents = conditioning_latents\n        self._kontext_image_ids = kontext_image_ids\n\n    @property\n    def prompt_embeds(self) -> mx.array:\n        return self._prompt_embeds\n\n    @property\n    def pooled_prompt_embeds(self) -> mx.array:\n        return self._pooled_prompt_embeds\n\n    @property\n    def negative_prompt_embeds(self) -> mx.array | None:\n        return None\n\n    @property\n    def negative_pooled_prompt_embeds(self) -> mx.array | None:\n        return None\n\n    def get_encoder_hidden_states_mask(self, positive: bool = True) -> mx.array | None:\n        return None\n\n    @property\n    def cond_image_grid(\n        self,\n    ) -> tuple[int, int, int] | list[tuple[int, int, int]] | None:\n        return None\n\n    @property\n    def conditioning_latents(self) -> mx.array | None:\n        \"\"\"VAE-encoded input image latents for Kontext conditioning.\"\"\"\n        return self._conditioning_latents\n\n    @property\n    def kontext_image_ids(self) -> mx.array | None:\n        \"\"\"Position IDs for Kontext conditioning (first_coord=1).\"\"\"\n        return self._kontext_image_ids\n\n    def get_cfg_branch_data(\n        self, positive: bool\n    ) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:\n        \"\"\"Kontext doesn't use CFG, but we return positive data for compatibility.\"\"\"\n        return (\n            self._prompt_embeds,\n            None,\n            self._pooled_prompt_embeds,\n            self._conditioning_latents,\n        )\n\n    def get_batched_cfg_data(\n        self,\n    ) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:\n        # Kontext doesn't use CFG\n        return None\n\n\n@final\nclass FluxKontextModelAdapter(ModelAdapter[Flux1Kontext, Transformer]):\n    \"\"\"Adapter for FLUX.1-Kontext image editing model.\n\n    Key differences from standard FluxModelAdapter:\n    - Takes an input image and computes output 
dimensions from it\n    - Creates conditioning latents from the input image via VAE\n    - Creates special position IDs (kontext_image_ids) for conditioning tokens\n    - Creates pure noise latents (not img2img blending)\n    \"\"\"\n\n    def __init__(\n        self,\n        config: ImageModelConfig,\n        model_id: str,\n        local_path: Path,\n        quantize: int | None = None,\n    ):\n        self._config = config\n        self._model = Flux1Kontext(\n            model_config=ModelConfig.from_name(model_name=model_id, base_model=None),\n            model_path=str(local_path),\n            quantize=quantize,\n        )\n        self._transformer = self._model.transformer\n\n        # Stores image path and computed dimensions after set_image_dimensions\n        self._image_path: str | None = None\n        self._output_height: int | None = None\n        self._output_width: int | None = None\n\n    @property\n    def hidden_dim(self) -> int:\n        return self._transformer.x_embedder.weight.shape[0]\n\n    @property\n    def needs_cfg(self) -> bool:\n        return False\n\n    def _get_latent_creator(self) -> type:\n        return FluxLatentCreator\n\n    def get_joint_block_wrappers(\n        self,\n        text_seq_len: int,\n        encoder_hidden_states_mask: mx.array | None = None,\n    ) -> list[JointBlockWrapper[Any]]:\n        \"\"\"Create wrapped joint blocks for Flux Kontext.\"\"\"\n        return [\n            FluxJointBlockWrapper(block, text_seq_len)\n            for block in self._transformer.transformer_blocks\n        ]\n\n    def get_single_block_wrappers(\n        self,\n        text_seq_len: int,\n    ) -> list[SingleBlockWrapper[Any]]:\n        \"\"\"Create wrapped single blocks for Flux Kontext.\"\"\"\n        return [\n            FluxSingleBlockWrapper(block, text_seq_len)\n            for block in self._transformer.single_transformer_blocks\n        ]\n\n    def slice_transformer_blocks(\n        self,\n        start_layer: int,\n        end_layer: int,\n    ):\n        all_joint = list(self._transformer.transformer_blocks)\n        all_single = list(self._transformer.single_transformer_blocks)\n        total_joint_blocks = len(all_joint)\n        if end_layer <= total_joint_blocks:\n            # All assigned are joint blocks\n            joint_start, joint_end = start_layer, end_layer\n            single_start, single_end = 0, 0\n        elif start_layer >= total_joint_blocks:\n            # All assigned are single blocks\n            joint_start, joint_end = 0, 0\n            single_start = start_layer - total_joint_blocks\n            single_end = end_layer - total_joint_blocks\n        else:\n            # Spans both joint and single\n            joint_start, joint_end = start_layer, total_joint_blocks\n            single_start = 0\n            single_end = end_layer - total_joint_blocks\n\n        self._transformer.transformer_blocks = all_joint[joint_start:joint_end]\n        self._transformer.single_transformer_blocks = all_single[\n            single_start:single_end\n        ]\n\n    def set_image_dimensions(self, image_path: Path) -> tuple[int, int]:\n        \"\"\"Compute and store dimensions from input image.\n\n        Also stores image_path for use in encode_prompt().\n\n        Args:\n            image_path: Path to the input image\n\n        Returns:\n            (output_width, output_height) for runtime config\n        \"\"\"\n        from mflux.utils.image_util import ImageUtil\n\n        pil_image = 
ImageUtil.load_image(str(image_path)).convert(\"RGB\")\n        image_size = pil_image.size\n\n        # Compute output dimensions from input image aspect ratio\n        # Target area of 1024x1024 = ~1M pixels\n        target_area = 1024 * 1024\n        ratio = image_size[0] / image_size[1]\n        output_width = math.sqrt(target_area * ratio)\n        output_height = output_width / ratio\n        output_width = round(output_width / 32) * 32\n        output_height = round(output_height / 32) * 32\n\n        # Ensure multiple of 16 for VAE\n        vae_scale_factor = 8\n        multiple_of = vae_scale_factor * 2\n        output_width = output_width // multiple_of * multiple_of\n        output_height = output_height // multiple_of * multiple_of\n\n        self._image_path = str(image_path)\n        self._output_width = int(output_width)\n        self._output_height = int(output_height)\n\n        return self._output_width, self._output_height\n\n    def create_latents(self, seed: int, runtime_config: Config) -> mx.array:\n        \"\"\"Create initial noise latents for Kontext.\n\n        Unlike standard img2img which blends noise with encoded input,\n        Kontext uses pure noise latents. The input image is provided\n        separately as conditioning.\n        \"\"\"\n        return FluxLatentCreator.create_noise(\n            seed=seed,\n            height=runtime_config.height,\n            width=runtime_config.width,\n        )\n\n    def encode_prompt(\n        self, prompt: str, negative_prompt: str | None = None\n    ) -> FluxKontextPromptData:\n        \"\"\"Encode prompt and create conditioning from stored input image.\n\n        Must call set_image_dimensions() before this method.\n\n        Args:\n            prompt: Text prompt for editing\n            negative_prompt: Ignored (Kontext doesn't use CFG)\n\n        Returns:\n            FluxKontextPromptData with text embeddings and image conditioning\n        \"\"\"\n        del negative_prompt  # Kontext doesn't support negative prompts or CFG\n\n        if (\n            self._image_path is None\n            or self._output_height is None\n            or self._output_width is None\n        ):\n            raise RuntimeError(\n                \"set_image_dimensions() must be called before encode_prompt() \"\n                \"for FluxKontextModelAdapter\"\n            )\n\n        assert isinstance(self.model.prompt_cache, dict)\n        assert isinstance(self.model.tokenizers, dict)\n\n        # Encode text prompt\n        prompt_embeds, pooled_prompt_embeds = PromptEncoder.encode_prompt(\n            prompt=prompt,\n            prompt_cache=self.model.prompt_cache,\n            t5_tokenizer=self.model.tokenizers[\"t5\"],  # pyright: ignore[reportAny]\n            clip_tokenizer=self.model.tokenizers[\"clip\"],  # pyright: ignore[reportAny]\n            t5_text_encoder=self.model.t5_text_encoder,\n            clip_text_encoder=self.model.clip_text_encoder,\n        )\n\n        # Create conditioning latents from input image\n        conditioning_latents, kontext_image_ids = (\n            KontextUtil.create_image_conditioning_latents(\n                vae=self.model.vae,\n                height=self._output_height,\n                width=self._output_width,\n                image_path=self._image_path,\n            )\n        )\n\n        return FluxKontextPromptData(\n            prompt_embeds=prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            conditioning_latents=conditioning_latents,\n      
      kontext_image_ids=kontext_image_ids,\n        )\n\n    def compute_embeddings(\n        self,\n        hidden_states: mx.array,\n        prompt_embeds: mx.array,\n    ) -> tuple[mx.array, mx.array]:\n        embedded_hidden = self._transformer.x_embedder(hidden_states)\n        embedded_encoder = self._transformer.context_embedder(prompt_embeds)\n        return embedded_hidden, embedded_encoder\n\n    def compute_text_embeddings(\n        self,\n        t: int,\n        runtime_config: Config,\n        pooled_prompt_embeds: mx.array | None = None,\n        hidden_states: mx.array | None = None,\n    ) -> mx.array:\n        if pooled_prompt_embeds is None:\n            raise ValueError(\n                \"pooled_prompt_embeds is required for Flux Kontext text embeddings\"\n            )\n\n        return Transformer.compute_text_embeddings(\n            t, pooled_prompt_embeds, self._transformer.time_text_embed, runtime_config\n        )\n\n    def compute_rotary_embeddings(\n        self,\n        prompt_embeds: mx.array,\n        runtime_config: Config,\n        encoder_hidden_states_mask: mx.array | None = None,\n        cond_image_grid: tuple[int, int, int]\n        | list[tuple[int, int, int]]\n        | None = None,\n        kontext_image_ids: mx.array | None = None,\n    ) -> RotaryEmbeddings:\n        return Transformer.compute_rotary_embeddings(\n            prompt_embeds,\n            self._transformer.pos_embed,\n            runtime_config,\n            kontext_image_ids,\n        )\n\n    def apply_guidance(\n        self,\n        noise_positive: mx.array,\n        noise_negative: mx.array,\n        guidance_scale: float,\n    ) -> mx.array:\n        raise NotImplementedError(\"Flux Kontext does not use classifier-free guidance\")\n"
  },
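  {
    "path": "docs/sketches/kontext_dimensions_sketch.py",
    "content": "\"\"\"Hedged sketch (illustrative, not shipped): Kontext output-dimension math.\n\nPure re-statement of set_image_dimensions from kontext_adapter.py: keep a\n~1024x1024 pixel budget at the input aspect ratio, snap to multiples of 32,\nthen floor to the VAE multiple of 16 (a no-op after the 32-snap).\n\"\"\"\n\nimport math\n\n\ndef kontext_output_dimensions(input_width: int, input_height: int) -> tuple[int, int]:\n    target_area = 1024 * 1024\n    ratio = input_width / input_height\n    width = math.sqrt(target_area * ratio)\n    height = width / ratio\n    snapped_width = round(width / 32) * 32\n    snapped_height = round(height / 32) * 32\n    multiple_of = 8 * 2  # vae_scale_factor * 2\n    return (\n        snapped_width // multiple_of * multiple_of,\n        snapped_height // multiple_of * multiple_of,\n    )\n\n\n# A 1920x1080 input is edited at 1376x768: ~1.06M pixels, 16:9 preserved.\nassert kontext_output_dimensions(1920, 1080) == (1376, 768)\n"
  },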
  {
    "path": "src/exo/worker/engines/image/models/flux/wrappers.py",
    "content": "from typing import final\n\nimport mlx.core as mx\nfrom mflux.models.flux.model.flux_transformer.common.attention_utils import (\n    AttentionUtils,\n)\nfrom mflux.models.flux.model.flux_transformer.joint_transformer_block import (\n    JointTransformerBlock,\n)\nfrom mflux.models.flux.model.flux_transformer.single_transformer_block import (\n    SingleTransformerBlock,\n)\nfrom pydantic import BaseModel, ConfigDict\n\nfrom exo.worker.engines.image.models.base import RotaryEmbeddings\nfrom exo.worker.engines.image.pipeline.block_wrapper import (\n    JointBlockWrapper,\n    SingleBlockWrapper,\n)\n\n\n@final\nclass FluxModulationParams(BaseModel):\n    model_config = ConfigDict(frozen=True, strict=True, arbitrary_types_allowed=True)\n\n    gate_msa: mx.array\n    shift_mlp: mx.array\n    scale_mlp: mx.array\n    gate_mlp: mx.array\n\n\n@final\nclass FluxNormGateState(BaseModel):\n    model_config = ConfigDict(frozen=True, strict=True, arbitrary_types_allowed=True)\n\n    norm_hidden: mx.array\n    gate: mx.array\n\n\nclass FluxJointBlockWrapper(JointBlockWrapper[JointTransformerBlock]):\n    def __init__(self, block: JointTransformerBlock, text_seq_len: int):\n        super().__init__(block, text_seq_len)\n        self._num_heads = block.attn.num_heads\n        self._head_dim = block.attn.head_dimension\n\n        # Intermediate state stored between _compute_qkv and _apply_output\n        self._hidden_mod: FluxModulationParams | None = None\n        self._context_mod: FluxModulationParams | None = None\n\n    def _compute_qkv(\n        self,\n        hidden_states: mx.array,\n        encoder_hidden_states: mx.array,\n        text_embeddings: mx.array,\n        rotary_embeddings: RotaryEmbeddings,\n        patch_mode: bool = False,\n    ) -> tuple[mx.array, mx.array, mx.array]:\n        assert isinstance(rotary_embeddings, mx.array)\n\n        attn = self.block.attn\n\n        (\n            norm_hidden,\n            gate_msa,\n            shift_mlp,\n            scale_mlp,\n            gate_mlp,\n        ) = self.block.norm1(\n            hidden_states=hidden_states,\n            text_embeddings=text_embeddings,\n        )\n        self._hidden_mod = FluxModulationParams(\n            gate_msa=gate_msa,\n            shift_mlp=shift_mlp,\n            scale_mlp=scale_mlp,\n            gate_mlp=gate_mlp,\n        )\n\n        (\n            norm_encoder,\n            c_gate_msa,\n            c_shift_mlp,\n            c_scale_mlp,\n            c_gate_mlp,\n        ) = self.block.norm1_context(\n            hidden_states=encoder_hidden_states,\n            text_embeddings=text_embeddings,\n        )\n        self._context_mod = FluxModulationParams(\n            gate_msa=c_gate_msa,\n            shift_mlp=c_shift_mlp,\n            scale_mlp=c_scale_mlp,\n            gate_mlp=c_gate_mlp,\n        )\n\n        img_query, img_key, img_value = AttentionUtils.process_qkv(\n            hidden_states=norm_hidden,\n            to_q=attn.to_q,\n            to_k=attn.to_k,\n            to_v=attn.to_v,\n            norm_q=attn.norm_q,\n            norm_k=attn.norm_k,\n            num_heads=self._num_heads,\n            head_dim=self._head_dim,\n        )\n\n        txt_query, txt_key, txt_value = AttentionUtils.process_qkv(\n            hidden_states=norm_encoder,\n            to_q=attn.add_q_proj,\n            to_k=attn.add_k_proj,\n            to_v=attn.add_v_proj,\n            norm_q=attn.norm_added_q,\n            norm_k=attn.norm_added_k,\n            num_heads=self._num_heads,\n     
       head_dim=self._head_dim,\n        )\n\n        query = mx.concatenate([txt_query, img_query], axis=2)\n        key = mx.concatenate([txt_key, img_key], axis=2)\n        value = mx.concatenate([txt_value, img_value], axis=2)\n\n        if patch_mode:\n            text_rope = rotary_embeddings[:, :, : self._text_seq_len, ...]\n            patch_img_rope = rotary_embeddings[\n                :,\n                :,\n                self._text_seq_len + self._patch_start : self._text_seq_len\n                + self._patch_end,\n                ...,\n            ]\n            rope = mx.concatenate([text_rope, patch_img_rope], axis=2)\n        else:\n            rope = rotary_embeddings\n\n        query, key = AttentionUtils.apply_rope(xq=query, xk=key, freqs_cis=rope)\n\n        return query, key, value\n\n    def _compute_attention(\n        self, query: mx.array, key: mx.array, value: mx.array\n    ) -> mx.array:\n        batch_size = query.shape[0]\n        return AttentionUtils.compute_attention(\n            query=query,\n            key=key,\n            value=value,\n            batch_size=batch_size,\n            num_heads=self._num_heads,\n            head_dim=self._head_dim,\n        )\n\n    def _apply_output(\n        self,\n        attn_out: mx.array,\n        hidden_states: mx.array,\n        encoder_hidden_states: mx.array,\n        text_embeddings: mx.array,\n    ) -> tuple[mx.array, mx.array]:\n        attn = self.block.attn\n\n        context_attn_output = attn_out[:, : self._text_seq_len, :]\n        hidden_attn_output = attn_out[:, self._text_seq_len :, :]\n\n        hidden_attn_output = attn.to_out[0](hidden_attn_output)  # pyright: ignore[reportAny]\n        context_attn_output = attn.to_add_out(context_attn_output)\n\n        assert self._hidden_mod is not None\n        assert self._context_mod is not None\n\n        hidden_states = JointTransformerBlock.apply_norm_and_feed_forward(\n            hidden_states=hidden_states,\n            attn_output=hidden_attn_output,  # pyright: ignore[reportAny]\n            gate_mlp=self._hidden_mod.gate_mlp,\n            gate_msa=self._hidden_mod.gate_msa,\n            scale_mlp=self._hidden_mod.scale_mlp,\n            shift_mlp=self._hidden_mod.shift_mlp,\n            norm_layer=self.block.norm2,\n            ff_layer=self.block.ff,\n        )\n        encoder_hidden_states = JointTransformerBlock.apply_norm_and_feed_forward(\n            hidden_states=encoder_hidden_states,\n            attn_output=context_attn_output,\n            gate_mlp=self._context_mod.gate_mlp,\n            gate_msa=self._context_mod.gate_msa,\n            scale_mlp=self._context_mod.scale_mlp,\n            shift_mlp=self._context_mod.shift_mlp,\n            norm_layer=self.block.norm2_context,\n            ff_layer=self.block.ff_context,\n        )\n\n        return encoder_hidden_states, hidden_states\n\n\nclass FluxSingleBlockWrapper(SingleBlockWrapper[SingleTransformerBlock]):\n    \"\"\"Flux-specific single block wrapper with pipefusion support.\"\"\"\n\n    def __init__(self, block: SingleTransformerBlock, text_seq_len: int):\n        super().__init__(block, text_seq_len)\n        self._num_heads = block.attn.num_heads\n        self._head_dim = block.attn.head_dimension\n\n        # Intermediate state stored between _compute_qkv and _apply_output\n        self._norm_state: FluxNormGateState | None = None\n\n    def _compute_qkv(\n        self,\n        hidden_states: mx.array,\n        text_embeddings: mx.array,\n        rotary_embeddings: 
RotaryEmbeddings,\n        patch_mode: bool = False,\n    ) -> tuple[mx.array, mx.array, mx.array]:\n        assert isinstance(rotary_embeddings, mx.array)\n\n        attn = self.block.attn\n\n        norm_hidden, gate = self.block.norm(\n            hidden_states=hidden_states,\n            text_embeddings=text_embeddings,\n        )\n        self._norm_state = FluxNormGateState(norm_hidden=norm_hidden, gate=gate)\n\n        query, key, value = AttentionUtils.process_qkv(\n            hidden_states=norm_hidden,\n            to_q=attn.to_q,\n            to_k=attn.to_k,\n            to_v=attn.to_v,\n            norm_q=attn.norm_q,\n            norm_k=attn.norm_k,\n            num_heads=self._num_heads,\n            head_dim=self._head_dim,\n        )\n\n        if patch_mode:\n            text_rope = rotary_embeddings[:, :, : self._text_seq_len, ...]\n            patch_img_rope = rotary_embeddings[\n                :,\n                :,\n                self._text_seq_len + self._patch_start : self._text_seq_len\n                + self._patch_end,\n                ...,\n            ]\n            rope = mx.concatenate([text_rope, patch_img_rope], axis=2)\n        else:\n            rope = rotary_embeddings\n\n        query, key = AttentionUtils.apply_rope(xq=query, xk=key, freqs_cis=rope)\n\n        return query, key, value\n\n    def _compute_attention(\n        self, query: mx.array, key: mx.array, value: mx.array\n    ) -> mx.array:\n        batch_size = query.shape[0]\n        return AttentionUtils.compute_attention(\n            query=query,\n            key=key,\n            value=value,\n            batch_size=batch_size,\n            num_heads=self._num_heads,\n            head_dim=self._head_dim,\n        )\n\n    def _apply_output(\n        self,\n        attn_out: mx.array,\n        hidden_states: mx.array,\n        text_embeddings: mx.array,\n    ) -> mx.array:\n        residual = hidden_states\n\n        assert self._norm_state is not None\n\n        output = self.block._apply_feed_forward_and_projection(\n            norm_hidden_states=self._norm_state.norm_hidden,\n            attn_output=attn_out,\n            gate=self._norm_state.gate,\n        )\n\n        return residual + output\n"
  },
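  {
    "path": "docs/sketches/patch_rope_rows_sketch.py",
    "content": "\"\"\"Hedged sketch (illustrative, not shipped): rope rows kept in patch mode.\n\nThe Flux wrappers slice the rotary embeddings so a patch pass sees all text\nrows plus only the image rows in [patch_start, patch_end). This re-states\nthat index arithmetic from _compute_qkv over plain row indices.\n\"\"\"\n\n\ndef patch_rope_rows(text_seq_len: int, patch_start: int, patch_end: int) -> list[int]:\n    text_rows = list(range(text_seq_len))\n    image_rows = list(range(text_seq_len + patch_start, text_seq_len + patch_end))\n    return text_rows + image_rows\n\n\n# 512 text tokens and an image patch over positions [0, 256): the kept rope\n# rows are 0..511 (text) followed by 512..767 (patch image rows).\nrows = patch_rope_rows(512, 0, 256)\nassert len(rows) == 768 and rows[512] == 512\n"
  },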
  {
    "path": "src/exo/worker/engines/image/models/qwen/__init__.py",
    "content": "from exo.worker.engines.image.models.qwen.adapter import QwenModelAdapter\nfrom exo.worker.engines.image.models.qwen.config import (\n    QWEN_IMAGE_CONFIG,\n    QWEN_IMAGE_EDIT_CONFIG,\n)\nfrom exo.worker.engines.image.models.qwen.edit_adapter import QwenEditModelAdapter\n\n__all__ = [\n    \"QwenModelAdapter\",\n    \"QwenEditModelAdapter\",\n    \"QWEN_IMAGE_CONFIG\",\n    \"QWEN_IMAGE_EDIT_CONFIG\",\n]\n"
  },
  {
    "path": "src/exo/worker/engines/image/models/qwen/adapter.py",
    "content": "from pathlib import Path\nfrom typing import Any\n\nimport mlx.core as mx\nfrom mflux.models.common.config import ModelConfig\nfrom mflux.models.common.config.config import Config\nfrom mflux.models.qwen.latent_creator.qwen_latent_creator import QwenLatentCreator\nfrom mflux.models.qwen.model.qwen_text_encoder.qwen_prompt_encoder import (\n    QwenPromptEncoder,\n)\nfrom mflux.models.qwen.model.qwen_transformer.qwen_transformer import QwenTransformer\nfrom mflux.models.qwen.variants.txt2img.qwen_image import QwenImage\n\nfrom exo.worker.engines.image.config import ImageModelConfig\nfrom exo.worker.engines.image.models.base import (\n    ModelAdapter,\n    PromptData,\n    RotaryEmbeddings,\n)\nfrom exo.worker.engines.image.models.qwen.wrappers import QwenJointBlockWrapper\nfrom exo.worker.engines.image.pipeline.block_wrapper import (\n    JointBlockWrapper,\n    SingleBlockWrapper,\n)\n\n\nclass QwenPromptData(PromptData):\n    def __init__(\n        self,\n        prompt_embeds: mx.array,\n        prompt_mask: mx.array,\n        negative_prompt_embeds: mx.array,\n        negative_prompt_mask: mx.array,\n    ):\n        self._prompt_embeds = prompt_embeds\n        self._prompt_mask = prompt_mask\n        self._negative_prompt_embeds = negative_prompt_embeds\n        self._negative_prompt_mask = negative_prompt_mask\n\n    @property\n    def prompt_embeds(self) -> mx.array:\n        return self._prompt_embeds\n\n    @property\n    def pooled_prompt_embeds(self) -> mx.array:\n        return self._prompt_embeds\n\n    @property\n    def negative_prompt_embeds(self) -> mx.array:\n        return self._negative_prompt_embeds\n\n    @property\n    def negative_pooled_prompt_embeds(self) -> mx.array:\n        return self._negative_prompt_embeds\n\n    def get_encoder_hidden_states_mask(self, positive: bool = True) -> mx.array:\n        if positive:\n            return self._prompt_mask\n        else:\n            return self._negative_prompt_mask\n\n    @property\n    def cond_image_grid(\n        self,\n    ) -> tuple[int, int, int] | list[tuple[int, int, int]] | None:\n        return None\n\n    @property\n    def conditioning_latents(self) -> mx.array | None:\n        return None\n\n    @property\n    def kontext_image_ids(self) -> mx.array | None:\n        return None\n\n    def get_batched_cfg_data(\n        self,\n    ) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:\n        \"\"\"Batch positive and negative embeddings for CFG with batch_size=2.\n\n        Pads shorter sequence to max length using zeros for embeddings\n        and zeros (masked) for attention mask.\n\n        Returns:\n            Tuple of (batched_embeds, batched_mask, None, conditioning_latents)\n            - batched_embeds: [2, max_seq, hidden]\n            - batched_mask: [2, max_seq]\n            - None for pooled (Qwen doesn't use it)\n            - conditioning_latents: [2, latent_seq, latent_dim] or None\n        \"\"\"\n        pos_embeds = self._prompt_embeds\n        neg_embeds = self._negative_prompt_embeds\n        pos_mask = self._prompt_mask\n        neg_mask = self._negative_prompt_mask\n\n        pos_seq_len = pos_embeds.shape[1]\n        neg_seq_len = neg_embeds.shape[1]\n        max_seq_len = max(pos_seq_len, neg_seq_len)\n        hidden_dim = pos_embeds.shape[2]\n\n        if pos_seq_len < max_seq_len:\n            pad_len = max_seq_len - pos_seq_len\n            pos_embeds = mx.concatenate(\n                [\n                    pos_embeds,\n                    
mx.zeros((1, pad_len, hidden_dim), dtype=pos_embeds.dtype),\n                ],\n                axis=1,\n            )\n            pos_mask = mx.concatenate(\n                [pos_mask, mx.zeros((1, pad_len), dtype=pos_mask.dtype)],\n                axis=1,\n            )\n\n        elif neg_seq_len < max_seq_len:\n            pad_len = max_seq_len - neg_seq_len\n            neg_embeds = mx.concatenate(\n                [\n                    neg_embeds,\n                    mx.zeros((1, pad_len, hidden_dim), dtype=neg_embeds.dtype),\n                ],\n                axis=1,\n            )\n            neg_mask = mx.concatenate(\n                [neg_mask, mx.zeros((1, pad_len), dtype=neg_mask.dtype)],\n                axis=1,\n            )\n\n        batched_embeds = mx.concatenate([pos_embeds, neg_embeds], axis=0)\n        batched_mask = mx.concatenate([pos_mask, neg_mask], axis=0)\n\n        # TODO(ciaran): currently None but maybe we will deduplicate with edit\n        # adapter\n        cond_latents = self.conditioning_latents\n        if cond_latents is not None:\n            cond_latents = mx.concatenate([cond_latents, cond_latents], axis=0)\n\n        return batched_embeds, batched_mask, None, cond_latents\n\n    def get_cfg_branch_data(\n        self, positive: bool\n    ) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:\n        if positive:\n            return (\n                self._prompt_embeds,\n                self._prompt_mask,\n                None,\n                self.conditioning_latents,\n            )\n        else:\n            return (\n                self._negative_prompt_embeds,\n                self._negative_prompt_mask,\n                None,\n                self.conditioning_latents,\n            )\n\n\nclass QwenModelAdapter(ModelAdapter[QwenImage, QwenTransformer]):\n    \"\"\"Adapter for Qwen-Image model.\n\n    Key differences from Flux:\n    - Single text encoder (vs dual T5+CLIP)\n    - 60 joint-style blocks, no single blocks\n    - 3D RoPE returning ((img_cos, img_sin), (txt_cos, txt_sin))\n    - Norm-preserving CFG with negative prompts\n    - Uses attention mask for variable-length text\n    \"\"\"\n\n    def __init__(\n        self,\n        config: ImageModelConfig,\n        model_id: str,\n        local_path: Path,\n        quantize: int | None = None,\n    ):\n        self._config = config\n        self._model = QwenImage(\n            model_config=ModelConfig.from_name(model_name=model_id, base_model=None),\n            model_path=str(local_path),\n            quantize=quantize,\n        )\n        self._transformer = self._model.transformer\n\n    @property\n    def hidden_dim(self) -> int:\n        return self._transformer.inner_dim\n\n    @property\n    def needs_cfg(self) -> bool:\n        gs = self._config.guidance_scale\n        return gs is not None and gs > 1.0\n\n    def _get_latent_creator(self) -> type[QwenLatentCreator]:\n        return QwenLatentCreator\n\n    def get_joint_block_wrappers(\n        self,\n        text_seq_len: int,\n        encoder_hidden_states_mask: mx.array | None = None,\n    ) -> list[JointBlockWrapper[Any]]:\n        \"\"\"Create wrapped joint blocks for Qwen.\"\"\"\n        return [\n            QwenJointBlockWrapper(block, text_seq_len, encoder_hidden_states_mask)\n            for block in self._transformer.transformer_blocks\n        ]\n\n    def get_single_block_wrappers(\n        self,\n        text_seq_len: int,\n    ) -> list[SingleBlockWrapper[Any]]:\n        \"\"\"Qwen has no single blocks.\"\"\"\n        return []\n\n    def slice_transformer_blocks(\n        self,\n        start_layer: int,\n        end_layer: int,\n    ) -> None:\n        self._transformer.transformer_blocks = self._transformer.transformer_blocks[\n            start_layer:end_layer\n        ]\n\n    def encode_prompt(\n        self, prompt: str, negative_prompt: str | None = None\n    ) -> QwenPromptData:\n        assert isinstance(self.model.prompt_cache, dict)\n        assert isinstance(self.model.tokenizers, dict)\n\n        if negative_prompt is None or negative_prompt == \"\":\n            negative_prompt = \" \"\n\n        prompt_embeds, prompt_mask, neg_embeds, neg_mask = (\n            QwenPromptEncoder.encode_prompt(\n                prompt=prompt,\n                negative_prompt=negative_prompt,\n                prompt_cache=self.model.prompt_cache,\n                qwen_tokenizer=self.model.tokenizers[\"qwen\"],  # pyright: ignore[reportAny]\n                qwen_text_encoder=self.model.text_encoder,\n            )\n        )\n\n        return QwenPromptData(\n            prompt_embeds=prompt_embeds,\n            prompt_mask=prompt_mask,\n            negative_prompt_embeds=neg_embeds,\n            negative_prompt_mask=neg_mask,\n        )\n\n    def compute_embeddings(\n        self,\n        hidden_states: mx.array,\n        prompt_embeds: mx.array,\n    ) -> tuple[mx.array, mx.array]:\n        embedded_hidden = self._transformer.img_in(hidden_states)\n        encoder_hidden_states = self._transformer.txt_norm(prompt_embeds)\n        embedded_encoder = self._transformer.txt_in(encoder_hidden_states)\n        return embedded_hidden, embedded_encoder\n\n    def compute_text_embeddings(\n        self,\n        t: int,\n        runtime_config: Config,\n        pooled_prompt_embeds: mx.array | None = None,\n        hidden_states: mx.array | None = None,\n    ) -> mx.array:\n        # Use hidden_states if provided, otherwise fall back to pooled_prompt_embeds\n        # (which for Qwen is the same as prompt_embeds)\n        ref_tensor = (\n            hidden_states if hidden_states is not None else pooled_prompt_embeds\n        )\n        if ref_tensor is None:\n            raise ValueError(\n                \"Either hidden_states or pooled_prompt_embeds is required \"\n                \"for Qwen text embeddings\"\n            )\n\n        timestep = QwenTransformer._compute_timestep(t, runtime_config)  # noqa: SLF001\n        batch_size = ref_tensor.shape[0]\n        timestep = mx.broadcast_to(timestep, (batch_size,)).astype(mx.float32)\n        return self._transformer.time_text_embed(timestep, ref_tensor)  # pyright: ignore[reportAny]\n\n    def compute_rotary_embeddings(\n        self,\n        prompt_embeds: mx.array,\n        runtime_config: Config,\n        encoder_hidden_states_mask: mx.array | None = None,\n        cond_image_grid: tuple[int, int, int]\n        | list[tuple[int, int, int]]\n        | None = None,\n        kontext_image_ids: mx.array | None = None,\n    ) -> RotaryEmbeddings:\n        if encoder_hidden_states_mask is None:\n            raise ValueError(\n                \"encoder_hidden_states_mask is required for Qwen RoPE computation\"\n            )\n\n        return QwenTransformer._compute_rotary_embeddings(\n            encoder_hidden_states_mask=encoder_hidden_states_mask,\n            pos_embed=self._transformer.pos_embed,  # pyright: ignore[reportAny]\n            config=runtime_config,\n            cond_image_grid=cond_image_grid,\n        )\n\n    def apply_guidance(\n        self,\n        noise_positive: mx.array,\n        noise_negative: mx.array,\n        guidance_scale: float,\n    ) -> mx.array:\n        return self._model.compute_guided_noise(\n            noise=noise_positive,\n            noise_negative=noise_negative,\n            guidance=guidance_scale,\n        )\n"
  },
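  {
    "path": "docs/examples/cfg_pad_and_batch_sketch.py",
    "content": "# Illustrative documentation-only sketch; not part of the engine runtime.\n# Mirrors the pad-to-max-and-stack scheme in QwenPromptData.get_batched_cfg_data:\n# the shorter of the positive/negative prompt embeddings is zero-padded (and\n# masked out) so both fit a single batch-of-2 CFG forward pass. Assumes only\n# that mlx is installed; all sizes are invented for the example.\nimport mlx.core as mx\n\n\ndef pad_row(embeds: mx.array, max_seq_len: int) -> tuple[mx.array, mx.array]:\n    \"\"\"Zero-pad one [1, seq, hidden] row to max_seq_len with a validity mask.\"\"\"\n    seq_len, hidden_dim = embeds.shape[1], embeds.shape[2]\n    if seq_len == max_seq_len:\n        return embeds, mx.ones((1, seq_len))\n    pad_len = max_seq_len - seq_len\n    mask = mx.concatenate([mx.ones((1, seq_len)), mx.zeros((1, pad_len))], axis=1)\n    padding = mx.zeros((1, pad_len, hidden_dim), dtype=embeds.dtype)\n    return mx.concatenate([embeds, padding], axis=1), mask\n\n\npositive = mx.random.normal((1, 7, 16))\nnegative = mx.random.normal((1, 3, 16))\nmax_seq_len = max(positive.shape[1], negative.shape[1])\npos_embeds, pos_mask = pad_row(positive, max_seq_len)\nneg_embeds, neg_mask = pad_row(negative, max_seq_len)\n\n# Row 0 is the positive branch, row 1 the negative branch.\nbatched_embeds = mx.concatenate([pos_embeds, neg_embeds], axis=0)\nbatched_mask = mx.concatenate([pos_mask, neg_mask], axis=0)\nassert batched_embeds.shape == (2, 7, 16)\nassert batched_mask.shape == (2, 7)\n"
  },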
  {
    "path": "src/exo/worker/engines/image/models/qwen/config.py",
    "content": "from exo.worker.engines.image.config import (\n    BlockType,\n    ImageModelConfig,\n    TransformerBlockConfig,\n)\n\nQWEN_IMAGE_CONFIG = ImageModelConfig(\n    model_family=\"qwen\",\n    block_configs=(\n        TransformerBlockConfig(\n            block_type=BlockType.JOINT, count=60, has_separate_text_output=True\n        ),\n    ),\n    default_steps={\"low\": 10, \"medium\": 25, \"high\": 50},\n    num_sync_steps=7,\n    guidance_scale=3.5,  # Set to None or < 1.0 to disable CFG\n)\n\nQWEN_IMAGE_EDIT_CONFIG = ImageModelConfig(\n    model_family=\"qwen-edit\",\n    block_configs=(\n        TransformerBlockConfig(\n            block_type=BlockType.JOINT, count=60, has_separate_text_output=True\n        ),\n    ),\n    default_steps={\"low\": 10, \"medium\": 25, \"high\": 50},\n    num_sync_steps=7,\n    guidance_scale=3.5,\n)\n"
  },
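  {
    "path": "docs/examples/needs_cfg_sketch.py",
    "content": "# Illustrative documentation-only sketch; not part of the engine runtime.\n# Mirrors the needs_cfg predicate shared by the Qwen adapters and explains the\n# config comment \"Set to None or < 1.0 to disable CFG\": classifier-free\n# guidance runs only when guidance_scale is set and strictly greater than 1.0.\n\n\ndef needs_cfg(guidance_scale: float | None) -> bool:\n    return guidance_scale is not None and guidance_scale > 1.0\n\n\nassert needs_cfg(3.5)  # QWEN_IMAGE_CONFIG default: CFG enabled\nassert not needs_cfg(1.0)  # a scale of exactly 1.0 is a no-op: disabled\nassert not needs_cfg(None)  # unset: disabled\n"
  },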
  {
    "path": "src/exo/worker/engines/image/models/qwen/edit_adapter.py",
    "content": "import math\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import Any\n\nimport mlx.core as mx\nfrom mflux.models.common.config.config import Config\nfrom mflux.models.qwen.latent_creator.qwen_latent_creator import QwenLatentCreator\nfrom mflux.models.qwen.model.qwen_transformer.qwen_transformer import QwenTransformer\nfrom mflux.models.qwen.variants.edit.qwen_edit_util import QwenEditUtil\nfrom mflux.models.qwen.variants.edit.qwen_image_edit import QwenImageEdit\n\nfrom exo.worker.engines.image.config import ImageModelConfig\nfrom exo.worker.engines.image.models.base import (\n    ModelAdapter,\n    PromptData,\n    RotaryEmbeddings,\n)\nfrom exo.worker.engines.image.models.qwen.wrappers import QwenJointBlockWrapper\nfrom exo.worker.engines.image.pipeline.block_wrapper import (\n    JointBlockWrapper,\n    SingleBlockWrapper,\n)\n\n\n@dataclass(frozen=True)\nclass EditImageDimensions:\n    vl_width: int\n    vl_height: int\n    vae_width: int\n    vae_height: int\n    image_paths: list[str]\n\n\nclass QwenEditPromptData(PromptData):\n    def __init__(\n        self,\n        prompt_embeds: mx.array,\n        prompt_mask: mx.array,\n        negative_prompt_embeds: mx.array,\n        negative_prompt_mask: mx.array,\n        conditioning_latents: mx.array,\n        qwen_image_ids: mx.array,\n        cond_image_grid: tuple[int, int, int] | list[tuple[int, int, int]],\n    ):\n        self._prompt_embeds = prompt_embeds\n        self._prompt_mask = prompt_mask\n        self._negative_prompt_embeds = negative_prompt_embeds\n        self._negative_prompt_mask = negative_prompt_mask\n        self._conditioning_latents = conditioning_latents\n        self._qwen_image_ids = qwen_image_ids\n        self._cond_image_grid = cond_image_grid\n\n    @property\n    def prompt_embeds(self) -> mx.array:\n        return self._prompt_embeds\n\n    @property\n    def pooled_prompt_embeds(self) -> mx.array:\n        return self._prompt_embeds\n\n    @property\n    def negative_prompt_embeds(self) -> mx.array:\n        return self._negative_prompt_embeds\n\n    @property\n    def negative_pooled_prompt_embeds(self) -> mx.array:\n        return self._negative_prompt_embeds\n\n    def get_encoder_hidden_states_mask(self, positive: bool = True) -> mx.array:\n        if positive:\n            return self._prompt_mask\n        else:\n            return self._negative_prompt_mask\n\n    @property\n    def cond_image_grid(self) -> tuple[int, int, int] | list[tuple[int, int, int]]:\n        return self._cond_image_grid\n\n    @property\n    def conditioning_latents(self) -> mx.array:\n        return self._conditioning_latents\n\n    @property\n    def qwen_image_ids(self) -> mx.array:\n        return self._qwen_image_ids\n\n    @property\n    def kontext_image_ids(self) -> mx.array | None:\n        return None\n\n    @property\n    def is_edit_mode(self) -> bool:\n        return True\n\n    def get_batched_cfg_data(\n        self,\n    ) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:\n        \"\"\"Batch positive and negative embeddings for CFG with batch_size=2.\n\n        Pads shorter sequence to max length using zeros for embeddings\n        and zeros (masked) for attention mask. 
Duplicates conditioning\n        latents for both positive and negative passes.\n\n        Returns:\n            Tuple of (batched_embeds, batched_mask, None, batched_cond_latents)\n            - batched_embeds: [2, max_seq, hidden]\n            - batched_mask: [2, max_seq]\n            - None for pooled (Qwen doesn't use it)\n            - batched_cond_latents: [2, latent_seq, latent_dim]\n            TODO(ciaran): type this\n        \"\"\"\n        pos_embeds = self._prompt_embeds\n        neg_embeds = self._negative_prompt_embeds\n        pos_mask = self._prompt_mask\n        neg_mask = self._negative_prompt_mask\n\n        pos_seq_len = pos_embeds.shape[1]\n        neg_seq_len = neg_embeds.shape[1]\n        max_seq_len = max(pos_seq_len, neg_seq_len)\n        hidden_dim = pos_embeds.shape[2]\n\n        if pos_seq_len < max_seq_len:\n            pad_len = max_seq_len - pos_seq_len\n            pos_embeds = mx.concatenate(\n                [\n                    pos_embeds,\n                    mx.zeros((1, pad_len, hidden_dim), dtype=pos_embeds.dtype),\n                ],\n                axis=1,\n            )\n            pos_mask = mx.concatenate(\n                [pos_mask, mx.zeros((1, pad_len), dtype=pos_mask.dtype)],\n                axis=1,\n            )\n\n        if neg_seq_len < max_seq_len:\n            pad_len = max_seq_len - neg_seq_len\n            neg_embeds = mx.concatenate(\n                [\n                    neg_embeds,\n                    mx.zeros((1, pad_len, hidden_dim), dtype=neg_embeds.dtype),\n                ],\n                axis=1,\n            )\n            neg_mask = mx.concatenate(\n                [neg_mask, mx.zeros((1, pad_len), dtype=neg_mask.dtype)],\n                axis=1,\n            )\n\n        batched_embeds = mx.concatenate([pos_embeds, neg_embeds], axis=0)\n        batched_mask = mx.concatenate([pos_mask, neg_mask], axis=0)\n\n        batched_cond_latents = mx.concatenate(\n            [self._conditioning_latents, self._conditioning_latents], axis=0\n        )\n\n        return batched_embeds, batched_mask, None, batched_cond_latents\n\n    def get_cfg_branch_data(\n        self, positive: bool\n    ) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:\n        if positive:\n            return (\n                self._prompt_embeds,\n                self._prompt_mask,\n                None,\n                self._conditioning_latents,\n            )\n        else:\n            return (\n                self._negative_prompt_embeds,\n                self._negative_prompt_mask,\n                None,\n                self._conditioning_latents,\n            )\n\n\nclass QwenEditModelAdapter(ModelAdapter[QwenImageEdit, QwenTransformer]):\n    \"\"\"Adapter for Qwen-Image-Edit model.\n\n    Key differences from standard QwenModelAdapter:\n    - Uses QwenImageEdit model with vision-language components\n    - Encodes prompts WITH input images via VL tokenizer/encoder\n    - Creates conditioning latents from input images\n    - Supports image editing with concatenated latents during diffusion\n    \"\"\"\n\n    def __init__(\n        self,\n        config: ImageModelConfig,\n        model_id: str,\n        local_path: Path,\n        quantize: int | None = None,\n    ):\n        self._config = config\n        self._model = QwenImageEdit(\n            quantize=quantize,\n            model_path=str(local_path),\n        )\n        self._transformer = self._model.transformer\n\n        self._edit_dimensions: 
EditImageDimensions | None = None\n\n    @property\n    def config(self) -> ImageModelConfig:\n        return self._config\n\n    @property\n    def model(self) -> QwenImageEdit:\n        return self._model\n\n    @property\n    def transformer(self) -> QwenTransformer:\n        return self._transformer\n\n    @property\n    def hidden_dim(self) -> int:\n        return self._transformer.inner_dim\n\n    @property\n    def needs_cfg(self) -> bool:\n        gs = self._config.guidance_scale\n        return gs is not None and gs > 1.0\n\n    def _get_latent_creator(self) -> type[QwenLatentCreator]:\n        return QwenLatentCreator\n\n    def get_joint_block_wrappers(\n        self,\n        text_seq_len: int,\n        encoder_hidden_states_mask: mx.array | None = None,\n    ) -> list[JointBlockWrapper[Any]]:\n        \"\"\"Create wrapped joint blocks for Qwen Edit.\"\"\"\n        return [\n            QwenJointBlockWrapper(block, text_seq_len, encoder_hidden_states_mask)\n            for block in self._transformer.transformer_blocks\n        ]\n\n    def get_single_block_wrappers(\n        self,\n        text_seq_len: int,\n    ) -> list[SingleBlockWrapper[Any]]:\n        \"\"\"Qwen has no single blocks.\"\"\"\n        return []\n\n    def slice_transformer_blocks(\n        self,\n        start_layer: int,\n        end_layer: int,\n    ):\n        self._transformer.transformer_blocks = self._transformer.transformer_blocks[\n            start_layer:end_layer\n        ]\n\n    def set_image_dimensions(self, image_path: Path) -> tuple[int, int]:\n        \"\"\"Compute and store dimensions from input image.\n\n        Also stores image_paths for use in encode_prompt().\n\n        Returns:\n            (output_width, output_height) for runtime config\n        \"\"\"\n        vl_w, vl_h, vae_w, vae_h, out_w, out_h = self._compute_dimensions_from_image(\n            image_path\n        )\n        self._edit_dimensions = EditImageDimensions(\n            vl_width=vl_w,\n            vl_height=vl_h,\n            vae_width=vae_w,\n            vae_height=vae_h,\n            image_paths=[str(image_path)],\n        )\n        return out_w, out_h\n\n    def create_latents(self, seed: int, runtime_config: Config) -> mx.array:\n        \"\"\"Create initial noise latents (pure noise for edit mode).\"\"\"\n        return QwenLatentCreator.create_noise(\n            seed=seed,\n            height=runtime_config.height,\n            width=runtime_config.width,\n        )\n\n    def encode_prompt(\n        self, prompt: str, negative_prompt: str | None = None\n    ) -> QwenEditPromptData:\n        dims = self._edit_dimensions\n        if dims is None:\n            raise RuntimeError(\n                \"set_image_dimensions() must be called before encode_prompt() \"\n                \"for QwenEditModelAdapter\"\n            )\n\n        if negative_prompt is None or negative_prompt == \"\":\n            negative_prompt = \" \"\n\n        # TODO(ciaran): config is untyped and unused, unsure if Config or RuntimeConfig is intended\n        (\n            prompt_embeds,\n            prompt_mask,\n            negative_prompt_embeds,\n            negative_prompt_mask,\n        ) = self._model._encode_prompts_with_images(\n            prompt,\n            negative_prompt,\n            dims.image_paths,\n            self._config,  # pyright: ignore[reportArgumentType]\n            dims.vl_width,\n            dims.vl_height,\n        )\n\n        (\n            conditioning_latents,\n            qwen_image_ids,\n            
cond_h_patches,\n            cond_w_patches,\n            num_images,\n        ) = QwenEditUtil.create_image_conditioning_latents(  # pyright: ignore[reportUnknownMemberType]\n            vae=self._model.vae,\n            height=dims.vae_height,\n            width=dims.vae_width,\n            image_paths=dims.image_paths,\n            vl_width=dims.vl_width,\n            vl_height=dims.vl_height,\n        )\n\n        if num_images > 1:\n            cond_image_grid: tuple[int, int, int] | list[tuple[int, int, int]] = [\n                (1, cond_h_patches, cond_w_patches) for _ in range(num_images)\n            ]\n        else:\n            cond_image_grid = (1, cond_h_patches, cond_w_patches)\n\n        return QwenEditPromptData(\n            prompt_embeds=prompt_embeds,\n            prompt_mask=prompt_mask,\n            negative_prompt_embeds=negative_prompt_embeds,\n            negative_prompt_mask=negative_prompt_mask,\n            conditioning_latents=conditioning_latents,\n            qwen_image_ids=qwen_image_ids,\n            cond_image_grid=cond_image_grid,\n        )\n\n    def compute_embeddings(\n        self,\n        hidden_states: mx.array,\n        prompt_embeds: mx.array,\n    ) -> tuple[mx.array, mx.array]:\n        embedded_hidden = self._transformer.img_in(hidden_states)\n        encoder_hidden_states = self._transformer.txt_norm(prompt_embeds)\n        embedded_encoder = self._transformer.txt_in(encoder_hidden_states)\n        return embedded_hidden, embedded_encoder\n\n    def compute_text_embeddings(\n        self,\n        t: int,\n        runtime_config: Config,\n        pooled_prompt_embeds: mx.array | None = None,\n        hidden_states: mx.array | None = None,\n    ) -> mx.array:\n        ref_tensor = (\n            hidden_states if hidden_states is not None else pooled_prompt_embeds\n        )\n        if ref_tensor is None:\n            raise ValueError(\n                \"Either hidden_states or pooled_prompt_embeds is required \"\n                \"for Qwen text embeddings\"\n            )\n\n        timestep = QwenTransformer._compute_timestep(t, runtime_config)  # noqa: SLF001\n        batch_size = ref_tensor.shape[0]\n        timestep = mx.broadcast_to(timestep, (batch_size,)).astype(mx.float32)\n        return self._transformer.time_text_embed(timestep, ref_tensor)  # pyright: ignore[reportAny]\n\n    def compute_rotary_embeddings(\n        self,\n        prompt_embeds: mx.array,\n        runtime_config: Config,\n        encoder_hidden_states_mask: mx.array | None = None,\n        cond_image_grid: tuple[int, int, int]\n        | list[tuple[int, int, int]]\n        | None = None,\n        kontext_image_ids: mx.array | None = None,\n    ) -> RotaryEmbeddings:\n        if encoder_hidden_states_mask is None:\n            raise ValueError(\n                \"encoder_hidden_states_mask is required for Qwen RoPE computation\"\n            )\n\n        return QwenTransformer._compute_rotary_embeddings(\n            encoder_hidden_states_mask=encoder_hidden_states_mask,\n            pos_embed=self._transformer.pos_embed,  # pyright: ignore[reportAny]\n            config=runtime_config,\n            cond_image_grid=cond_image_grid,\n        )\n\n    def apply_guidance(\n        self,\n        noise_positive: mx.array,\n        noise_negative: mx.array,\n        guidance_scale: float,\n    ) -> mx.array:\n        from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage\n\n        return QwenImage.compute_guided_noise(\n            noise=noise_positive,\n  
          noise_negative=noise_negative,\n            guidance=guidance_scale,\n        )\n\n    def _compute_dimensions_from_image(\n        self, image_path: Path\n    ) -> tuple[int, int, int, int, int, int]:\n        from mflux.utils.image_util import ImageUtil\n\n        pil_image = ImageUtil.load_image(str(image_path)).convert(\"RGB\")\n        image_size = pil_image.size\n\n        # Vision-language dimensions (384x384 target area)\n        condition_image_size = 384 * 384\n        condition_ratio = image_size[0] / image_size[1]\n        vl_width = math.sqrt(condition_image_size * condition_ratio)\n        vl_height = vl_width / condition_ratio\n        vl_width = round(vl_width / 32) * 32\n        vl_height = round(vl_height / 32) * 32\n\n        # VAE dimensions (1024x1024 target area)\n        vae_image_size = 1024 * 1024\n        vae_ratio = image_size[0] / image_size[1]\n        vae_width = math.sqrt(vae_image_size * vae_ratio)\n        vae_height = vae_width / vae_ratio\n        vae_width = round(vae_width / 32) * 32\n        vae_height = round(vae_height / 32) * 32\n\n        # Output dimensions from input image aspect ratio\n        target_area = 1024 * 1024\n        ratio = image_size[0] / image_size[1]\n        output_width = math.sqrt(target_area * ratio)\n        output_height = output_width / ratio\n        output_width = round(output_width / 32) * 32\n        output_height = round(output_height / 32) * 32\n\n        # Ensure multiple of 16 for VAE\n        vae_scale_factor = 8\n        multiple_of = vae_scale_factor * 2\n        output_width = output_width // multiple_of * multiple_of\n        output_height = output_height // multiple_of * multiple_of\n\n        return (\n            int(vl_width),\n            int(vl_height),\n            int(vae_width),\n            int(vae_height),\n            int(output_width),\n            int(output_height),\n        )\n"
  },
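  {
    "path": "docs/examples/edit_dimension_math_sketch.py",
    "content": "# Illustrative documentation-only sketch; not part of the engine runtime.\n# Reproduces the aspect-preserving resize math from\n# QwenEditModelAdapter._compute_dimensions_from_image: fit width/height to a\n# fixed target area, round both to multiples of 32, then floor the output size\n# to a multiple of 16 (vae_scale_factor * 2). The 1536x1024 input is invented.\nimport math\n\n\ndef fit_to_area(width: int, height: int, target_area: int) -> tuple[int, int]:\n    ratio = width / height\n    fitted_width = math.sqrt(target_area * ratio)\n    fitted_height = fitted_width / ratio\n    return round(fitted_width / 32) * 32, round(fitted_height / 32) * 32\n\n\n# A 1536x1024 (3:2) input image:\nvl_width, vl_height = fit_to_area(1536, 1024, 384 * 384)\nvae_width, vae_height = fit_to_area(1536, 1024, 1024 * 1024)\nassert (vl_width, vl_height) == (480, 320)\nassert (vae_width, vae_height) == (1248, 832)\n\n# Output dimensions are additionally floored to a multiple of 16 for the VAE.\nmultiple_of = 8 * 2\noutput_width = vae_width // multiple_of * multiple_of\noutput_height = vae_height // multiple_of * multiple_of\nassert (output_width, output_height) == (1248, 832)\n"
  },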
  {
    "path": "src/exo/worker/engines/image/models/qwen/wrappers.py",
    "content": "from typing import final\n\nimport mlx.core as mx\nfrom mflux.models.qwen.model.qwen_transformer.qwen_attention import QwenAttention\nfrom mflux.models.qwen.model.qwen_transformer.qwen_transformer_block import (\n    QwenTransformerBlock,\n)\nfrom pydantic import BaseModel, ConfigDict\n\nfrom exo.worker.engines.image.models.base import RotaryEmbeddings\nfrom exo.worker.engines.image.pipeline.block_wrapper import JointBlockWrapper\n\n\n@final\nclass QwenStreamModulation(BaseModel):\n    model_config = ConfigDict(frozen=True, strict=True, arbitrary_types_allowed=True)\n\n    mod1: mx.array\n    mod2: mx.array\n    gate1: mx.array\n\n\nclass QwenJointBlockWrapper(JointBlockWrapper[QwenTransformerBlock]):\n    def __init__(\n        self,\n        block: QwenTransformerBlock,\n        text_seq_len: int,\n        encoder_hidden_states_mask: mx.array | None = None,\n    ):\n        super().__init__(block, text_seq_len)\n        self._encoder_hidden_states_mask = encoder_hidden_states_mask\n\n        self._num_heads = block.attn.num_heads\n        self._head_dim = block.attn.head_dim\n\n        # Intermediate state stored between _compute_qkv and _apply_output\n        self._img_mod: QwenStreamModulation | None = None\n        self._txt_mod: QwenStreamModulation | None = None\n\n    def set_encoder_mask(self, mask: mx.array | None) -> None:\n        \"\"\"Set the encoder hidden states mask for attention.\"\"\"\n        self._encoder_hidden_states_mask = mask\n\n    def _compute_qkv(\n        self,\n        hidden_states: mx.array,\n        encoder_hidden_states: mx.array,\n        text_embeddings: mx.array,\n        rotary_embeddings: RotaryEmbeddings,\n        patch_mode: bool = False,\n    ) -> tuple[mx.array, mx.array, mx.array]:\n        assert isinstance(rotary_embeddings, tuple)\n\n        batch_size = hidden_states.shape[0]\n        img_seq_len = hidden_states.shape[1]\n        attn = self.block.attn\n\n        img_mod_params = self.block.img_mod_linear(\n            self.block.img_mod_silu(text_embeddings)  # pyright: ignore[reportUnknownArgumentType]\n        )\n        txt_mod_params = self.block.txt_mod_linear(\n            self.block.txt_mod_silu(text_embeddings)  # pyright: ignore[reportUnknownArgumentType]\n        )\n\n        img_mod1, img_mod2 = mx.split(img_mod_params, 2, axis=-1)\n        txt_mod1, txt_mod2 = mx.split(txt_mod_params, 2, axis=-1)\n\n        img_normed = self.block.img_norm1(hidden_states)\n        img_modulated, img_gate1 = QwenTransformerBlock._modulate(img_normed, img_mod1)\n        self._img_mod = QwenStreamModulation(\n            mod1=img_mod1, mod2=img_mod2, gate1=img_gate1\n        )\n\n        txt_normed = self.block.txt_norm1(encoder_hidden_states)\n        txt_modulated, txt_gate1 = QwenTransformerBlock._modulate(txt_normed, txt_mod1)\n        self._txt_mod = QwenStreamModulation(\n            mod1=txt_mod1, mod2=txt_mod2, gate1=txt_gate1\n        )\n\n        img_query = attn.to_q(img_modulated)\n        img_key = attn.to_k(img_modulated)\n        img_value = attn.to_v(img_modulated)\n\n        txt_query = attn.add_q_proj(txt_modulated)\n        txt_key = attn.add_k_proj(txt_modulated)\n        txt_value = attn.add_v_proj(txt_modulated)\n\n        img_query = mx.reshape(\n            img_query, (batch_size, img_seq_len, self._num_heads, self._head_dim)\n        )\n        img_key = mx.reshape(\n            img_key, (batch_size, img_seq_len, self._num_heads, self._head_dim)\n        )\n        img_value = mx.reshape(\n            
img_value, (batch_size, img_seq_len, self._num_heads, self._head_dim)\n        )\n\n        txt_query = mx.reshape(\n            txt_query,\n            (batch_size, self._text_seq_len, self._num_heads, self._head_dim),\n        )\n        txt_key = mx.reshape(\n            txt_key, (batch_size, self._text_seq_len, self._num_heads, self._head_dim)\n        )\n        txt_value = mx.reshape(\n            txt_value, (batch_size, self._text_seq_len, self._num_heads, self._head_dim)\n        )\n\n        img_query = attn.norm_q(img_query)\n        img_key = attn.norm_k(img_key)\n        txt_query = attn.norm_added_q(txt_query)\n        txt_key = attn.norm_added_k(txt_key)\n\n        (img_cos, img_sin), (txt_cos, txt_sin) = rotary_embeddings\n\n        if patch_mode:\n            # Slice image RoPE for patch, keep full text RoPE\n            img_cos = img_cos[self._patch_start : self._patch_end]\n            img_sin = img_sin[self._patch_start : self._patch_end]\n\n        img_query = QwenAttention._apply_rope_qwen(img_query, img_cos, img_sin)\n        img_key = QwenAttention._apply_rope_qwen(img_key, img_cos, img_sin)\n        txt_query = QwenAttention._apply_rope_qwen(txt_query, txt_cos, txt_sin)\n        txt_key = QwenAttention._apply_rope_qwen(txt_key, txt_cos, txt_sin)\n\n        img_query = mx.transpose(img_query, (0, 2, 1, 3))\n        img_key = mx.transpose(img_key, (0, 2, 1, 3))\n        img_value = mx.transpose(img_value, (0, 2, 1, 3))\n\n        txt_query = mx.transpose(txt_query, (0, 2, 1, 3))\n        txt_key = mx.transpose(txt_key, (0, 2, 1, 3))\n        txt_value = mx.transpose(txt_value, (0, 2, 1, 3))\n\n        query = mx.concatenate([txt_query, img_query], axis=2)\n        key = mx.concatenate([txt_key, img_key], axis=2)\n        value = mx.concatenate([txt_value, img_value], axis=2)\n\n        return query, key, value\n\n    def _compute_attention(\n        self, query: mx.array, key: mx.array, value: mx.array\n    ) -> mx.array:\n        attn = self.block.attn\n\n        mask = QwenAttention._convert_mask_for_qwen(\n            mask=self._encoder_hidden_states_mask,\n            joint_seq_len=key.shape[2],\n            txt_seq_len=self._text_seq_len,\n        )\n\n        query_bshd = mx.transpose(query, (0, 2, 1, 3))\n        key_bshd = mx.transpose(key, (0, 2, 1, 3))\n        value_bshd = mx.transpose(value, (0, 2, 1, 3))\n\n        return attn._compute_attention_qwen(\n            query=query_bshd,\n            key=key_bshd,\n            value=value_bshd,\n            mask=mask,\n            block_idx=None,\n        )\n\n    def _apply_output(\n        self,\n        attn_out: mx.array,\n        hidden_states: mx.array,\n        encoder_hidden_states: mx.array,\n        text_embeddings: mx.array,\n    ) -> tuple[mx.array, mx.array]:\n        attn = self.block.attn\n\n        assert self._img_mod is not None\n        assert self._txt_mod is not None\n\n        txt_attn_output = attn_out[:, : self._text_seq_len, :]\n        img_attn_output = attn_out[:, self._text_seq_len :, :]\n\n        img_attn_output = attn.attn_to_out[0](img_attn_output)  # pyright: ignore[reportAny]\n        txt_attn_output = attn.to_add_out(txt_attn_output)\n\n        hidden_states = hidden_states + self._img_mod.gate1 * img_attn_output  # pyright: ignore[reportAny]\n        encoder_hidden_states = (\n            encoder_hidden_states + self._txt_mod.gate1 * txt_attn_output\n        )\n\n        img_normed2 = self.block.img_norm2(hidden_states)\n        img_modulated2, img_gate2 = 
QwenTransformerBlock._modulate(\n            img_normed2, self._img_mod.mod2\n        )\n        img_mlp_output = self.block.img_ff(img_modulated2)  # pyright: ignore[reportAny]\n        hidden_states = hidden_states + img_gate2 * img_mlp_output  # pyright: ignore[reportAny]\n\n        txt_normed2 = self.block.txt_norm2(encoder_hidden_states)\n        txt_modulated2, txt_gate2 = QwenTransformerBlock._modulate(\n            txt_normed2, self._txt_mod.mod2\n        )\n        txt_mlp_output = self.block.txt_ff(txt_modulated2)  # pyright: ignore[reportAny]\n        encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output  # pyright: ignore[reportAny]\n\n        return encoder_hidden_states, hidden_states\n"
  },
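  {
    "path": "docs/examples/joint_stream_layout_sketch.py",
    "content": "# Illustrative documentation-only sketch; not part of the engine runtime.\n# Shows the tensor-layout dance from QwenJointBlockWrapper._compute_qkv:\n# projections come out as [batch, seq, heads * head_dim], are reshaped to\n# [batch, seq, heads, head_dim], transposed to [batch, heads, seq, head_dim],\n# and the text stream is concatenated ahead of the image stream on the\n# sequence axis. Sizes are invented for the example.\nimport mlx.core as mx\n\nbatch, txt_seq, img_seq, heads, head_dim = 1, 4, 9, 2, 8\n\ntxt = mx.random.normal((batch, txt_seq, heads * head_dim))\nimg = mx.random.normal((batch, img_seq, heads * head_dim))\n\ntxt = mx.transpose(mx.reshape(txt, (batch, txt_seq, heads, head_dim)), (0, 2, 1, 3))\nimg = mx.transpose(mx.reshape(img, (batch, img_seq, heads, head_dim)), (0, 2, 1, 3))\n\n# Text first, image second: this ordering is why the wrappers slice text K/V\n# as key[:, :, :text_seq_len, :] and image K/V as key[:, :, text_seq_len:, :].\njoint = mx.concatenate([txt, img], axis=2)\nassert joint.shape == (batch, heads, txt_seq + img_seq, head_dim)\n"
  },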
  {
    "path": "src/exo/worker/engines/image/pipeline/__init__.py",
    "content": "from exo.worker.engines.image.pipeline.block_wrapper import (\n    BlockWrapperMode,\n    JointBlockWrapper,\n    SingleBlockWrapper,\n)\nfrom exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache\nfrom exo.worker.engines.image.pipeline.runner import DiffusionRunner\n\n__all__ = [\n    \"BlockWrapperMode\",\n    \"DiffusionRunner\",\n    \"ImagePatchKVCache\",\n    \"JointBlockWrapper\",\n    \"SingleBlockWrapper\",\n]\n"
  },
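  {
    "path": "docs/examples/patch_partition_sketch.py",
    "content": "# Illustrative documentation-only sketch; not part of the engine runtime.\n# Re-derives the patch partitioning used by the pipeline runner\n# (calculate_patch_heights / calculate_token_indices in runner.py): split the\n# latent height into near-equal row bands, then map each band to a contiguous\n# [start, end) token range, since every latent row contributes latent_width\n# tokens. The 10x4 latent grid and 3 patches are invented for the example.\nfrom math import ceil\n\nlatent_height, latent_width, num_patches = 10, 4, 3\n\npatch_height = ceil(latent_height / num_patches)\nactual_num_patches = ceil(latent_height / patch_height)\npatch_heights = [patch_height] * (actual_num_patches - 1)\npatch_heights.append(latent_height - patch_height * (actual_num_patches - 1))\nassert patch_heights == [4, 4, 2]\n\ntoken_ranges: list[tuple[int, int]] = []\ncumulative_height = 0\nfor h in patch_heights:\n    token_ranges.append(\n        (latent_width * cumulative_height, latent_width * (cumulative_height + h))\n    )\n    cumulative_height += h\nassert token_ranges == [(0, 16), (16, 32), (32, 40)]\n"
  },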
  {
    "path": "src/exo/worker/engines/image/pipeline/block_wrapper.py",
    "content": "from abc import ABC, abstractmethod\nfrom enum import Enum\nfrom typing import Generic, Self, TypeVar\n\nimport mlx.core as mx\n\nfrom exo.worker.engines.image.models.base import RotaryEmbeddings\nfrom exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache\n\nBlockT = TypeVar(\"BlockT\")\n\n\nclass BlockWrapperMode(Enum):\n    CACHING = \"caching\"  # Sync mode: compute full attention, populate cache\n    PATCHED = \"patched\"  # Async mode: compute patch attention, use cached KV\n\n\nclass BlockWrapperMixin:\n    \"\"\"Common cache management logic for block wrappers.\n\n    Including:\n    - KV cache creation and management\n    - Mode\n    - Patch range setting\n    \"\"\"\n\n    _text_seq_len: int\n    _kv_cache: ImagePatchKVCache | None\n    _mode: BlockWrapperMode\n    _patch_start: int\n    _patch_end: int\n\n    def _init_cache_state(self, text_seq_len: int) -> None:\n        self._text_seq_len = text_seq_len\n        self._kv_cache = None\n        self._mode = BlockWrapperMode.CACHING\n        self._patch_start = 0\n        self._patch_end = 0\n\n    def set_patch(\n        self,\n        mode: BlockWrapperMode,\n        patch_start: int = 0,\n        patch_end: int = 0,\n    ) -> Self:\n        \"\"\"Set mode and patch range.\n\n        Args:\n            mode: CACHING (full attention) or PATCHED (use cached KV)\n            patch_start: Start token index within image (for PATCHED mode)\n            patch_end: End token index within image (for PATCHED mode)\n\n        Returns:\n            Self for method chaining\n        \"\"\"\n        self._mode = mode\n        self._patch_start = patch_start\n        self._patch_end = patch_end\n        return self\n\n    def set_text_seq_len(self, text_seq_len: int) -> None:\n        self._text_seq_len = text_seq_len\n\n    def _get_active_cache(self) -> ImagePatchKVCache | None:\n        return self._kv_cache\n\n    def _ensure_cache(self, img_key: mx.array) -> None:\n        if self._kv_cache is None:\n            batch, num_heads, img_seq_len, head_dim = img_key.shape\n            self._kv_cache = ImagePatchKVCache(\n                batch_size=batch,\n                num_heads=num_heads,\n                image_seq_len=img_seq_len,\n                head_dim=head_dim,\n            )\n\n    def _cache_full_image_kv(self, img_key: mx.array, img_value: mx.array) -> None:\n        self._ensure_cache(img_key)\n        cache = self._get_active_cache()\n        assert cache is not None\n        cache.update_image_patch(0, img_key.shape[2], img_key, img_value)\n\n    def _cache_patch_kv(self, img_key: mx.array, img_value: mx.array) -> None:\n        cache = self._get_active_cache()\n        assert cache is not None\n        cache.update_image_patch(self._patch_start, self._patch_end, img_key, img_value)\n\n    def _get_full_kv(\n        self, text_key: mx.array, text_value: mx.array\n    ) -> tuple[mx.array, mx.array]:\n        cache = self._get_active_cache()\n        assert cache is not None\n        return cache.get_full_kv(text_key, text_value)\n\n    def reset_cache(self) -> None:\n        self._kv_cache = None\n\n\nclass JointBlockWrapper(BlockWrapperMixin, ABC, Generic[BlockT]):\n    \"\"\"Base class for joint transformer block wrappers with pipefusion support.\n\n    The wrapper:\n    - Owns its KV cache (created lazily on first CACHING forward)\n    - Controls the forward pass flow (CACHING vs PATCHED mode)\n    - Handles patch slicing and cache operations\n    \"\"\"\n\n    block: BlockT\n\n    def 
__init__(self, block: BlockT, text_seq_len: int):\n        self.block = block\n        self._init_cache_state(text_seq_len)\n\n    def set_encoder_mask(self, mask: mx.array | None) -> None:  # noqa: B027\n        \"\"\"Set the encoder hidden states mask for attention.\n\n        Override in subclasses that use attention masks\n        Default is a no-op for models that don't use masks\n        \"\"\"\n        del mask  # Unused in base class\n\n    def __call__(\n        self,\n        hidden_states: mx.array,\n        encoder_hidden_states: mx.array,\n        text_embeddings: mx.array,\n        rotary_embeddings: RotaryEmbeddings,\n    ) -> tuple[mx.array, mx.array]:\n        if self._mode == BlockWrapperMode.CACHING:\n            return self._forward_caching(\n                hidden_states, encoder_hidden_states, text_embeddings, rotary_embeddings\n            )\n        return self._forward_patched(\n            hidden_states, encoder_hidden_states, text_embeddings, rotary_embeddings\n        )\n\n    def _forward_caching(\n        self,\n        hidden_states: mx.array,\n        encoder_hidden_states: mx.array,\n        text_embeddings: mx.array,\n        rotary_embeddings: RotaryEmbeddings,\n    ) -> tuple[mx.array, mx.array]:\n        \"\"\"CACHING mode: Full attention, store image K/V in cache.\"\"\"\n        query, key, value = self._compute_qkv(\n            hidden_states, encoder_hidden_states, text_embeddings, rotary_embeddings\n        )\n\n        img_key = key[:, :, self._text_seq_len :, :]\n        img_value = value[:, :, self._text_seq_len :, :]\n        self._cache_full_image_kv(img_key, img_value)\n\n        attn_out = self._compute_attention(query, key, value)\n\n        return self._apply_output(\n            attn_out, hidden_states, encoder_hidden_states, text_embeddings\n        )\n\n    def _forward_patched(\n        self,\n        hidden_states: mx.array,\n        encoder_hidden_states: mx.array,\n        text_embeddings: mx.array,\n        rotary_embeddings: RotaryEmbeddings,\n    ) -> tuple[mx.array, mx.array]:\n        # hidden_states is already the patch (provided by runner)\n        patch_hidden = hidden_states\n\n        query, key, value = self._compute_qkv(\n            patch_hidden,\n            encoder_hidden_states,\n            text_embeddings,\n            rotary_embeddings,\n            patch_mode=True,\n        )\n\n        text_key = key[:, :, : self._text_seq_len, :]\n        text_value = value[:, :, : self._text_seq_len, :]\n        img_key = key[:, :, self._text_seq_len :, :]\n        img_value = value[:, :, self._text_seq_len :, :]\n\n        self._cache_patch_kv(img_key, img_value)\n        full_key, full_value = self._get_full_kv(text_key, text_value)\n\n        attn_out = self._compute_attention(query, full_key, full_value)\n\n        return self._apply_output(\n            attn_out, patch_hidden, encoder_hidden_states, text_embeddings\n        )\n\n    @abstractmethod\n    def _compute_qkv(\n        self,\n        hidden_states: mx.array,\n        encoder_hidden_states: mx.array,\n        text_embeddings: mx.array,\n        rotary_embeddings: RotaryEmbeddings,\n        patch_mode: bool = False,\n    ) -> tuple[mx.array, mx.array, mx.array]: ...\n\n    @abstractmethod\n    def _compute_attention(\n        self, query: mx.array, key: mx.array, value: mx.array\n    ) -> mx.array: ...\n\n    @abstractmethod\n    def _apply_output(\n        self,\n        attn_out: mx.array,\n        hidden_states: mx.array,\n        encoder_hidden_states: 
mx.array,\n        text_embeddings: mx.array,\n    ) -> tuple[mx.array, mx.array]: ...\n\n\nclass SingleBlockWrapper(BlockWrapperMixin, ABC, Generic[BlockT]):\n    \"\"\"Base class for single-stream transformer block wrappers.\n\n    Similar to JointBlockWrapper but for blocks that operate on a single\n    concatenated [text, image] stream rather than separate streams.\n    \"\"\"\n\n    block: BlockT\n\n    def __init__(self, block: BlockT, text_seq_len: int):\n        self.block = block\n        self._init_cache_state(text_seq_len)\n\n    def __call__(\n        self,\n        hidden_states: mx.array,\n        text_embeddings: mx.array,\n        rotary_embeddings: RotaryEmbeddings,\n    ) -> mx.array:\n        if self._mode == BlockWrapperMode.CACHING:\n            return self._forward_caching(\n                hidden_states, text_embeddings, rotary_embeddings\n            )\n        return self._forward_patched(hidden_states, text_embeddings, rotary_embeddings)\n\n    def _forward_caching(\n        self,\n        hidden_states: mx.array,\n        text_embeddings: mx.array,\n        rotary_embeddings: RotaryEmbeddings,\n    ) -> mx.array:\n        \"\"\"CACHING mode: Full attention, store image K/V in cache.\"\"\"\n        query, key, value = self._compute_qkv(\n            hidden_states, text_embeddings, rotary_embeddings\n        )\n\n        img_key = key[:, :, self._text_seq_len :, :]\n        img_value = value[:, :, self._text_seq_len :, :]\n        self._cache_full_image_kv(img_key, img_value)\n\n        attn_out = self._compute_attention(query, key, value)\n\n        return self._apply_output(attn_out, hidden_states, text_embeddings)\n\n    def _forward_patched(\n        self,\n        hidden_states: mx.array,\n        text_embeddings: mx.array,\n        rotary_embeddings: RotaryEmbeddings,\n    ) -> mx.array:\n        \"\"\"PATCHED mode: Compute patch Q/K/V, use cached image K/V for attention.\"\"\"\n        query, key, value = self._compute_qkv(\n            hidden_states, text_embeddings, rotary_embeddings, patch_mode=True\n        )\n\n        text_key = key[:, :, : self._text_seq_len, :]\n        text_value = value[:, :, : self._text_seq_len, :]\n        img_key = key[:, :, self._text_seq_len :, :]\n        img_value = value[:, :, self._text_seq_len :, :]\n\n        self._cache_patch_kv(img_key, img_value)\n        full_key, full_value = self._get_full_kv(text_key, text_value)\n\n        attn_out = self._compute_attention(query, full_key, full_value)\n\n        return self._apply_output(attn_out, hidden_states, text_embeddings)\n\n    @abstractmethod\n    def _compute_qkv(\n        self,\n        hidden_states: mx.array,\n        text_embeddings: mx.array,\n        rotary_embeddings: RotaryEmbeddings,\n        patch_mode: bool = False,\n    ) -> tuple[mx.array, mx.array, mx.array]: ...\n\n    @abstractmethod\n    def _compute_attention(\n        self, query: mx.array, key: mx.array, value: mx.array\n    ) -> mx.array: ...\n\n    @abstractmethod\n    def _apply_output(\n        self,\n        attn_out: mx.array,\n        hidden_states: mx.array,\n        text_embeddings: mx.array,\n    ) -> mx.array: ...\n"
  },
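  {
    "path": "docs/examples/patched_mode_cache_sketch.py",
    "content": "# Illustrative documentation-only sketch; not part of the engine runtime.\n# Miniature version of the stale-fresh trade behind BlockWrapperMode: in\n# CACHING mode a block computes K/V for the whole image and stores it; in\n# PATCHED mode it recomputes K/V only for its patch, overwrites that slice,\n# and attends against fresh-patch plus stale-rest K/V (the PipeFusion\n# approximation). Shapes are invented for the example.\nimport mlx.core as mx\n\nimg_seq, head_dim = 8, 4\n\n# CACHING step: a full-image key fills the cache.\ncache = mx.random.normal((img_seq, head_dim))\n\n# PATCHED step: only image tokens [2, 5) are recomputed this round.\npatch_start, patch_end = 2, 5\nfresh = mx.random.normal((patch_end - patch_start, head_dim))\ncache = mx.concatenate([cache[:patch_start], fresh, cache[patch_end:]], axis=0)\n\n# Tokens outside [2, 5) are one step stale; the next patch refreshes its own\n# slice in the same way.\nassert cache.shape == (img_seq, head_dim)\n"
  },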
  {
    "path": "src/exo/worker/engines/image/pipeline/kv_cache.py",
    "content": "import mlx.core as mx\n\n\nclass ImagePatchKVCache:\n    \"\"\"KV cache that stores only IMAGE K/V with patch-level updates.\n\n    Only caches image K/V since:\n    - Text K/V is always computed fresh (same for all patches)\n    - Only image portion needs stale/fresh cache management across patches\n    \"\"\"\n\n    def __init__(\n        self,\n        batch_size: int,\n        num_heads: int,\n        image_seq_len: int,\n        head_dim: int,\n        dtype: mx.Dtype = mx.float32,\n    ):\n        self.batch_size = batch_size\n        self.num_heads = num_heads\n        self.image_seq_len = image_seq_len\n        self.head_dim = head_dim\n        self._dtype = dtype\n\n        self.key_cache = mx.zeros(\n            (batch_size, num_heads, image_seq_len, head_dim), dtype=dtype\n        )\n        self.value_cache = mx.zeros(\n            (batch_size, num_heads, image_seq_len, head_dim), dtype=dtype\n        )\n\n    def update_image_patch(\n        self, patch_start: int, patch_end: int, key: mx.array, value: mx.array\n    ) -> None:\n        \"\"\"Update cache with fresh K/V for an image patch slice.\n\n        Args:\n            patch_start: Start token index within image portion (0-indexed)\n            patch_end: End token index within image portion\n            key: Fresh key tensor [batch, heads, patch_seq_len, head_dim]\n            value: Fresh value tensor [batch, heads, patch_seq_len, head_dim]\n        \"\"\"\n        self.key_cache[:, :, patch_start:patch_end, :] = key\n        self.value_cache[:, :, patch_start:patch_end, :] = value\n\n    def get_full_kv(\n        self, text_key: mx.array, text_value: mx.array\n    ) -> tuple[mx.array, mx.array]:\n        \"\"\"Return full K/V by concatenating fresh text K/V with cached image K/V.\n\n        Args:\n            text_key: Fresh text key tensor [batch, heads, text_seq_len, head_dim]\n            text_value: Fresh text value tensor [batch, heads, text_seq_len, head_dim]\n\n        Returns:\n            Tuple of (full_key, full_value) with shape [batch, heads, text+image, head_dim]\n        \"\"\"\n        full_key = mx.concatenate([text_key, self.key_cache], axis=2)\n        full_value = mx.concatenate([text_value, self.value_cache], axis=2)\n        return full_key, full_value\n\n    def reset(self) -> None:\n        \"\"\"Reset cache to zeros.\"\"\"\n        self.key_cache = mx.zeros(\n            (self.batch_size, self.num_heads, self.image_seq_len, self.head_dim),\n            dtype=self._dtype,\n        )\n        self.value_cache = mx.zeros(\n            (self.batch_size, self.num_heads, self.image_seq_len, self.head_dim),\n            dtype=self._dtype,\n        )\n"
  },
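  {
    "path": "docs/examples/image_patch_kv_cache_sketch.py",
    "content": "# Illustrative documentation-only sketch; not part of the engine runtime.\n# Usage sketch for ImagePatchKVCache: only image K/V is cached, patch by\n# patch, and fresh text K/V is spliced back in front on every call. Assumes\n# the exo package is importable; shapes are invented for the example.\nimport mlx.core as mx\n\nfrom exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache\n\nbatch, heads, img_seq, txt_seq, head_dim = 1, 2, 8, 3, 4\ncache = ImagePatchKVCache(\n    batch_size=batch, num_heads=heads, image_seq_len=img_seq, head_dim=head_dim\n)\n\n# A patch covering image tokens [2, 5) writes its fresh K/V into the cache.\npatch_key = mx.random.normal((batch, heads, 3, head_dim))\npatch_value = mx.random.normal((batch, heads, 3, head_dim))\ncache.update_image_patch(2, 5, patch_key, patch_value)\n\n# Text K/V is always computed fresh and concatenated ahead of the image K/V.\ntext_key = mx.random.normal((batch, heads, txt_seq, head_dim))\ntext_value = mx.random.normal((batch, heads, txt_seq, head_dim))\nfull_key, full_value = cache.get_full_kv(text_key, text_value)\nassert full_key.shape == (batch, heads, txt_seq + img_seq, head_dim)\nassert full_value.shape == (batch, heads, txt_seq + img_seq, head_dim)\n"
  },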
  {
    "path": "src/exo/worker/engines/image/pipeline/runner.py",
    "content": "from collections.abc import Iterator\nfrom dataclasses import dataclass\nfrom math import ceil\nfrom typing import Any, Optional, final\n\nimport mlx.core as mx\nfrom mflux.models.common.config.config import Config\nfrom mflux.utils.exceptions import StopImageGenerationException\nfrom tqdm import tqdm\n\nfrom exo.shared.constants import EXO_TRACING_ENABLED\nfrom exo.shared.tracing import (\n    clear_trace_buffer,\n    trace,\n)\nfrom exo.shared.types.worker.shards import CfgShardMetadata, PipelineShardMetadata\nfrom exo.worker.engines.image.config import ImageModelConfig\nfrom exo.worker.engines.image.models.base import (\n    ModelAdapter,\n    PromptData,\n    RotaryEmbeddings,\n)\nfrom exo.worker.engines.image.pipeline.block_wrapper import (\n    BlockWrapperMode,\n    JointBlockWrapper,\n    SingleBlockWrapper,\n)\n\n\n@final\n@dataclass(frozen=True)\nclass CfgBranch:\n    positive: bool\n    embeds: mx.array\n    mask: mx.array | None\n    pooled: mx.array | None\n    cond_latents: mx.array | None\n\n\ndef calculate_patch_heights(\n    latent_height: int, num_patches: int\n) -> tuple[list[int], int]:\n    patch_height = ceil(latent_height / num_patches)\n\n    actual_num_patches = ceil(latent_height / patch_height)\n    patch_heights = [patch_height] * (actual_num_patches - 1)\n\n    last_height = latent_height - patch_height * (actual_num_patches - 1)\n    patch_heights.append(last_height)\n\n    return patch_heights, actual_num_patches\n\n\ndef calculate_token_indices(\n    patch_heights: list[int], latent_width: int\n) -> list[tuple[int, int]]:\n    tokens_per_row = latent_width\n\n    token_ranges: list[tuple[int, int]] = []\n    cumulative_height = 0\n\n    for h in patch_heights:\n        start_token = tokens_per_row * cumulative_height\n        end_token = tokens_per_row * (cumulative_height + h)\n\n        token_ranges.append((start_token, end_token))\n        cumulative_height += h\n\n    return token_ranges\n\n\nclass DiffusionRunner:\n    \"\"\"Orchestrates the diffusion loop for image generation.\n\n    In distributed mode, it implements PipeFusion with:\n    - Sync pipeline for initial timesteps (full image, all devices in lockstep)\n    - Async pipeline for later timesteps (patches processed independently)\n    \"\"\"\n\n    def __init__(\n        self,\n        config: ImageModelConfig,\n        adapter: ModelAdapter[Any, Any],\n        group: Optional[mx.distributed.Group],\n        shard_metadata: PipelineShardMetadata | CfgShardMetadata,\n        num_patches: Optional[int] = None,\n    ):\n        self.config = config\n        self.adapter = adapter\n        self.group = group\n\n        self._init_cfg_topology(shard_metadata)\n\n        self.num_patches = (\n            num_patches if num_patches else max(1, self.pipeline_world_size)\n        )\n\n        self.total_joint = config.joint_block_count\n        self.total_single = config.single_block_count\n        self.total_layers = config.total_blocks\n\n        self._guidance_override: float | None = None\n\n        self._compute_assigned_blocks()\n\n    def _init_cfg_topology(\n        self, shard_metadata: PipelineShardMetadata | CfgShardMetadata\n    ) -> None:\n        \"\"\"Initialize CFG and pipeline topology from shard metadata.\n\n        Both CfgShardMetadata and PipelineShardMetadata represent pipeline parallel\n        execution. 
CFG adds a second parallel pipeline for negative prompt processing,\n        but within each pipeline group the communication pattern is identical.\n        \"\"\"\n        if self.group is None:\n            # Single node - no distributed communication\n            self.rank = 0\n            self.world_size = 1\n            self.start_layer = 0\n            self.end_layer = self.config.total_blocks\n            self.cfg_rank = 0\n            self.cfg_world_size = 1\n            self.cfg_parallel = False\n            self.pipeline_rank = 0\n            self.pipeline_world_size = 1\n            self.next_pipeline_rank: int | None = None\n            self.prev_pipeline_rank: int | None = None\n            self.cfg_peer_rank: int | None = None\n            self.first_pipeline_rank: int = 0\n            self.last_pipeline_rank: int = 0\n            return\n\n        # Common fields from base metadata\n        self.rank = shard_metadata.device_rank\n        self.world_size = shard_metadata.world_size\n        self.start_layer = shard_metadata.start_layer\n        self.end_layer = shard_metadata.end_layer\n\n        if isinstance(shard_metadata, CfgShardMetadata):\n            # CFG parallel: two independent pipelines\n            self.cfg_rank = shard_metadata.cfg_rank\n            self.cfg_world_size = shard_metadata.cfg_world_size\n            self.cfg_parallel = True\n            self.pipeline_rank = shard_metadata.pipeline_rank\n            self.pipeline_world_size = shard_metadata.pipeline_world_size\n        else:\n            # Pure pipeline: single pipeline group, sequential CFG\n            self.cfg_rank = 0\n            self.cfg_world_size = 1\n            self.cfg_parallel = False\n            self.pipeline_rank = shard_metadata.device_rank\n            self.pipeline_world_size = shard_metadata.world_size\n\n        # Pipeline neighbor computation (same logic for both types)\n        is_first = self.pipeline_rank == 0\n        is_last = self.pipeline_rank == self.pipeline_world_size - 1\n\n        self.next_pipeline_rank = (\n            None\n            if is_last\n            else self._device_rank_for(self.cfg_rank, self.pipeline_rank + 1)\n        )\n        self.prev_pipeline_rank = (\n            None\n            if is_first\n            else self._device_rank_for(self.cfg_rank, self.pipeline_rank - 1)\n        )\n\n        # CFG peer is the corresponding last stage in the other CFG group\n        if self.cfg_parallel and is_last:\n            other_cfg_rank = 1 - self.cfg_rank\n            self.cfg_peer_rank = self._device_rank_for(\n                other_cfg_rank, self.pipeline_rank\n            )\n        else:\n            self.cfg_peer_rank = None\n\n        # First/last pipeline ranks for ring communication (latent broadcast)\n        self.first_pipeline_rank = self._device_rank_for(self.cfg_rank, 0)\n        self.last_pipeline_rank = self._device_rank_for(\n            self.cfg_rank, self.pipeline_world_size - 1\n        )\n\n    def _device_rank_for(self, cfg_rank: int, pipeline_rank: int) -> int:\n        \"\"\"Convert (cfg_rank, pipeline_rank) to device_rank in the ring topology.\n\n        Ring layout: [cfg0_pipe0, cfg0_pipe1, ..., cfg1_pipeN-1, cfg1_pipeN-2, ..., cfg1_pipe0]\n        Group 0 is in ascending order, group 1 is reversed so last stages are neighbors.\n        \"\"\"\n        if not self.cfg_parallel:\n            return pipeline_rank\n        if cfg_rank == 0:\n            return pipeline_rank\n        else:\n            return self.world_size - 1 - 
pipeline_rank\n\n    def _compute_assigned_blocks(self) -> None:\n        \"\"\"Determine which joint/single blocks this stage owns.\"\"\"\n        start = self.start_layer\n        end = self.end_layer\n\n        if end <= self.total_joint:\n            self.joint_start = start\n            self.joint_end = end\n            self.single_start = 0\n            self.single_end = 0\n        elif start >= self.total_joint:\n            self.joint_start = 0\n            self.joint_end = 0\n            self.single_start = start - self.total_joint\n            self.single_end = end - self.total_joint\n        else:\n            self.joint_start = start\n            self.joint_end = self.total_joint\n            self.single_start = 0\n            self.single_end = end - self.total_joint\n\n        self.has_joint_blocks = self.joint_end > self.joint_start\n        self.has_single_blocks = self.single_end > self.single_start\n\n        self.owns_concat_stage = self.has_joint_blocks and (\n            self.has_single_blocks or self.end_layer == self.total_joint\n        )\n\n        # Wrappers created lazily on first forward (need text_seq_len)\n        self.joint_block_wrappers: list[JointBlockWrapper[Any]] | None = None\n        self.single_block_wrappers: list[SingleBlockWrapper[Any]] | None = None\n        self._wrappers_initialized = False\n        self._current_text_seq_len: int | None = None\n\n    @property\n    def is_first_stage(self) -> bool:\n        return self.pipeline_rank == 0\n\n    @property\n    def is_last_stage(self) -> bool:\n        return self.pipeline_rank == self.pipeline_world_size - 1\n\n    @property\n    def is_distributed(self) -> bool:\n        return self.group is not None\n\n    def _get_effective_guidance_scale(self) -> float | None:\n        if self._guidance_override is not None:\n            return self._guidance_override\n        return self.config.guidance_scale\n\n    def _get_cfg_branches(self, prompt_data: PromptData) -> Iterator[CfgBranch]:\n        \"\"\"Yield the CFG branches this node should process.\n\n        - No CFG: yields one branch (positive)\n        - CFG parallel: yields one branch (our assigned branch)\n        - Sequential CFG: yields two branches (positive, then negative)\n        \"\"\"\n        if not self.adapter.needs_cfg:\n            embeds, mask, pooled, cond = prompt_data.get_cfg_branch_data(positive=True)\n            yield CfgBranch(\n                positive=True,\n                embeds=embeds,\n                mask=mask,\n                pooled=pooled,\n                cond_latents=cond,\n            )\n        elif self.cfg_parallel:\n            positive = self.cfg_rank == 0\n            embeds, mask, pooled, cond = prompt_data.get_cfg_branch_data(positive)\n            yield CfgBranch(\n                positive=positive,\n                embeds=embeds,\n                mask=mask,\n                pooled=pooled,\n                cond_latents=cond,\n            )\n        else:\n            pos_embeds, pos_mask, pos_pooled, pos_cond = (\n                prompt_data.get_cfg_branch_data(positive=True)\n            )\n            yield CfgBranch(\n                positive=True,\n                embeds=pos_embeds,\n                mask=pos_mask,\n                pooled=pos_pooled,\n                cond_latents=pos_cond,\n            )\n            neg_embeds, neg_mask, neg_pooled, neg_cond = (\n                prompt_data.get_cfg_branch_data(positive=False)\n            )\n            yield CfgBranch(\n                
positive=False,\n                embeds=neg_embeds,\n                mask=neg_mask,\n                pooled=neg_pooled,\n                cond_latents=neg_cond,\n            )\n\n    def _combine_cfg_results(self, results: list[tuple[bool, mx.array]]) -> mx.array:\n        if len(results) == 1:\n            positive, noise = results[0]\n            if self.cfg_parallel and self.is_last_stage:\n                # TODO(ciaran): try to remove\n                mx.eval(noise)\n                return self._exchange_and_apply_guidance(noise, positive)\n            return noise\n\n        noise_neg = next(n for p, n in results if not p)\n        noise_pos = next(n for p, n in results if p)\n        return self._apply_guidance(noise_pos, noise_neg)\n\n    def _exchange_and_apply_guidance(\n        self, noise: mx.array, is_positive: bool\n    ) -> mx.array:\n        assert self.group is not None\n        assert self.cfg_peer_rank is not None\n\n        if is_positive:\n            noise = mx.distributed.send(noise, self.cfg_peer_rank, group=self.group)\n            mx.async_eval(noise)\n            noise_neg = mx.distributed.recv_like(\n                noise, self.cfg_peer_rank, group=self.group\n            )\n            mx.eval(noise_neg)\n            noise_pos = noise\n        else:\n            noise_pos = mx.distributed.recv_like(\n                noise, self.cfg_peer_rank, group=self.group\n            )\n            mx.eval(noise_pos)\n            noise = mx.distributed.send(noise, self.cfg_peer_rank, group=self.group)\n            mx.async_eval(noise)\n            noise_neg = noise\n\n        return self._apply_guidance(noise_pos, noise_neg)\n\n    def _apply_guidance(self, noise_pos: mx.array, noise_neg: mx.array) -> mx.array:\n        scale = self._get_effective_guidance_scale()\n        assert scale is not None\n        return self.adapter.apply_guidance(noise_pos, noise_neg, scale)\n\n    def _ensure_wrappers(\n        self,\n        text_seq_len: int,\n        encoder_hidden_states_mask: mx.array | None = None,\n    ) -> None:\n        \"\"\"Lazily create block wrappers on first forward pass.\n\n        Wrappers need text_seq_len which is only known after prompt encoding.\n        Re-initializes if text_seq_len changes (e.g., warmup vs real generation).\n        \"\"\"\n        if self._wrappers_initialized and self._current_text_seq_len == text_seq_len:\n            return\n\n        self.joint_block_wrappers = self.adapter.get_joint_block_wrappers(\n            text_seq_len=text_seq_len,\n            encoder_hidden_states_mask=encoder_hidden_states_mask,\n        )\n        self.single_block_wrappers = self.adapter.get_single_block_wrappers(\n            text_seq_len=text_seq_len,\n        )\n        self._wrappers_initialized = True\n        self._current_text_seq_len = text_seq_len\n\n    def _reset_all_caches(self) -> None:\n        \"\"\"Reset KV caches on all wrappers for a new generation.\"\"\"\n        if self.joint_block_wrappers:\n            for wrapper in self.joint_block_wrappers:\n                wrapper.reset_cache()\n        if self.single_block_wrappers:\n            for wrapper in self.single_block_wrappers:\n                wrapper.reset_cache()\n\n    def _set_text_seq_len(self, text_seq_len: int) -> None:\n        if self.joint_block_wrappers:\n            for wrapper in self.joint_block_wrappers:\n                wrapper.set_text_seq_len(text_seq_len)\n        if self.single_block_wrappers:\n            for wrapper in self.single_block_wrappers:\n                
wrapper.set_text_seq_len(text_seq_len)\n\n    def _calculate_capture_steps(\n        self,\n        partial_images: int,\n        init_time_step: int,\n        num_inference_steps: int,\n    ) -> set[int]:\n        \"\"\"Calculate which timesteps should produce partial images.\n\n        Places the first partial after step 1 for fast initial feedback,\n        then evenly spaces remaining partials with equal gaps between them\n        and from the last partial to the final image.\n\n        Args:\n            partial_images: Number of partial images to capture\n            init_time_step: Starting timestep (for img2img this may not be 0)\n            num_inference_steps: Total inference steps\n\n        Returns:\n            Set of timestep indices to capture\n        \"\"\"\n        if partial_images <= 0:\n            return set()\n\n        total_steps = num_inference_steps - init_time_step\n        if total_steps <= 1:\n            return set()\n\n        if partial_images >= total_steps - 1:\n            return set(range(init_time_step, num_inference_steps - 1))\n\n        capture_steps: set[int] = set()\n\n        first_capture = init_time_step + 1\n        capture_steps.add(first_capture)\n\n        if partial_images == 1:\n            return capture_steps\n\n        final_step = num_inference_steps - 1\n        remaining_range = final_step - first_capture\n\n        for i in range(1, partial_images):\n            step_idx = first_capture + int(i * remaining_range / partial_images)\n            capture_steps.add(step_idx)\n\n        return capture_steps\n\n    def generate_image(\n        self,\n        runtime_config: Config,\n        prompt: str,\n        seed: int,\n        partial_images: int = 0,\n        guidance_override: float | None = None,\n        negative_prompt: str | None = None,\n        num_sync_steps: int = 1,\n    ):\n        \"\"\"Primary entry point for image generation.\n\n        Orchestrates the full generation flow:\n        1. Create runtime config\n        2. Create initial latents\n        3. Encode prompt\n        4. Run diffusion loop (yielding partials if requested)\n        5. 
Decode to image\n\n        Args:\n            settings: Generation config (steps, height, width)\n            prompt: Text prompt\n            seed: Random seed\n            partial_images: Number of intermediate images to yield (0 for none)\n            guidance_override: Optional override for guidance scale (CFG)\n\n        Yields:\n            Partial images as (GeneratedImage, partial_index, total_partials) tuples\n            Final GeneratedImage\n        \"\"\"\n        self._guidance_override = guidance_override\n        latents = self.adapter.create_latents(seed, runtime_config)\n        prompt_data = self.adapter.encode_prompt(prompt, negative_prompt)\n\n        capture_steps = self._calculate_capture_steps(\n            partial_images=partial_images,\n            init_time_step=runtime_config.init_time_step,\n            num_inference_steps=runtime_config.num_inference_steps,\n        )\n\n        diffusion_gen = self._run_diffusion_loop(\n            latents=latents,\n            prompt_data=prompt_data,\n            runtime_config=runtime_config,\n            seed=seed,\n            prompt=prompt,\n            capture_steps=capture_steps,\n            num_sync_steps=num_sync_steps,\n        )\n\n        partial_index = 0\n        total_partials = len(capture_steps)\n\n        if capture_steps:\n            try:\n                while True:\n                    partial_latents, _step = next(diffusion_gen)\n                    if self.is_last_stage:\n                        partial_image = self.adapter.decode_latents(\n                            partial_latents, runtime_config, seed, prompt\n                        )\n                        yield (partial_image, partial_index, total_partials)\n                        partial_index += 1\n            except StopIteration as e:\n                latents = e.value  # pyright: ignore[reportAny]\n        else:\n            try:\n                while True:\n                    next(diffusion_gen)\n            except StopIteration as e:\n                latents = e.value  # pyright: ignore[reportAny]\n\n        if self.is_last_stage:\n            yield self.adapter.decode_latents(latents, runtime_config, seed, prompt)  # pyright: ignore[reportAny]\n\n    def _run_diffusion_loop(\n        self,\n        latents: mx.array,\n        prompt_data: PromptData,\n        runtime_config: Config,\n        seed: int,\n        prompt: str,\n        num_sync_steps: int,\n        capture_steps: set[int] | None = None,\n    ):\n        if capture_steps is None:\n            capture_steps = set()\n\n        self._reset_all_caches()\n        clear_trace_buffer()\n\n        time_steps = tqdm(range(runtime_config.num_inference_steps))\n\n        ctx = self.adapter.model.callbacks.start(  # pyright: ignore[reportAny]\n            seed=seed, prompt=prompt, config=runtime_config\n        )\n\n        ctx.before_loop(  # pyright: ignore[reportAny]\n            latents=latents,\n        )\n\n        for t in time_steps:\n            try:\n                latents = self._diffusion_step(\n                    t=t,\n                    config=runtime_config,\n                    latents=latents,\n                    prompt_data=prompt_data,\n                    num_sync_steps=num_sync_steps,\n                )\n\n                ctx.in_loop(  # pyright: ignore[reportAny]\n                    t=t,\n                    latents=latents,\n                    time_steps=time_steps,\n                )\n\n                mx.eval(latents)\n\n                if t in 
capture_steps and self.is_last_stage:\n                    yield (latents, t)\n\n            except KeyboardInterrupt:  # noqa: PERF203\n                ctx.interruption(t=t, latents=latents)  # pyright: ignore[reportAny]\n                raise StopImageGenerationException(\n                    f\"Stopping image generation at step {t + 1}/{len(time_steps)}\"\n                ) from None\n\n        ctx.after_loop(latents=latents)  # pyright: ignore[reportAny]\n\n        return latents\n\n    def _forward_pass(\n        self,\n        latents: mx.array,\n        prompt_embeds: mx.array,\n        pooled_prompt_embeds: mx.array,\n        t: int,\n        config: Config,\n        encoder_hidden_states_mask: mx.array | None = None,\n        cond_image_grid: tuple[int, int, int]\n        | list[tuple[int, int, int]]\n        | None = None,\n        conditioning_latents: mx.array | None = None,\n        kontext_image_ids: mx.array | None = None,\n    ) -> mx.array:\n        \"\"\"Run a single forward pass through the transformer.\n\n        Args:\n            latents: Input latents (already scaled by caller)\n            prompt_embeds: Text embeddings\n            pooled_prompt_embeds: Pooled text embeddings (Flux) or placeholder (Qwen)\n            t: Current timestep\n            config: Runtime configuration\n            encoder_hidden_states_mask: Attention mask for text (Qwen)\n            cond_image_grid: Conditioning image grid dimensions (Qwen edit)\n            conditioning_latents: Conditioning latents for edit mode\n            kontext_image_ids: Position IDs for Kontext conditioning (Flux Kontext)\n\n        Returns:\n            Noise prediction tensor\n        \"\"\"\n        text_seq_len = prompt_embeds.shape[1]\n\n        self._ensure_wrappers(text_seq_len, encoder_hidden_states_mask)\n\n        if self.joint_block_wrappers and encoder_hidden_states_mask is not None:\n            for wrapper in self.joint_block_wrappers:\n                wrapper.set_encoder_mask(encoder_hidden_states_mask)\n\n        scaled_latents = config.scheduler.scale_model_input(latents, t)  # pyright: ignore[reportAny]\n\n        # For edit mode: concatenate with conditioning latents\n        original_latent_tokens: int = scaled_latents.shape[1]  # pyright: ignore[reportAny]\n        if conditioning_latents is not None:\n            scaled_latents = mx.concatenate(\n                [scaled_latents, conditioning_latents], axis=1\n            )\n\n        hidden_states, encoder_hidden_states = self.adapter.compute_embeddings(\n            scaled_latents, prompt_embeds\n        )\n        text_embeddings = self.adapter.compute_text_embeddings(\n            t, config, pooled_prompt_embeds, hidden_states=hidden_states\n        )\n        rotary_embeddings = self.adapter.compute_rotary_embeddings(\n            prompt_embeds,\n            config,\n            encoder_hidden_states_mask=encoder_hidden_states_mask,\n            cond_image_grid=cond_image_grid,\n            kontext_image_ids=kontext_image_ids,\n        )\n\n        assert self.joint_block_wrappers is not None\n        for wrapper in self.joint_block_wrappers:\n            encoder_hidden_states, hidden_states = wrapper(\n                hidden_states=hidden_states,\n                encoder_hidden_states=encoder_hidden_states,\n                text_embeddings=text_embeddings,\n                rotary_embeddings=rotary_embeddings,\n            )\n
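\n        # Joint blocks keep text and image tokens in two separate streams;\n        # merge them into a single sequence (text first) before the\n        # single-stream blocks below.\n        if self.joint_block_wrappers:\n            hidden_states = self.adapter.merge_streams(\n               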
 hidden_states, encoder_hidden_states\n            )\n\n        assert self.single_block_wrappers is not None\n        for wrapper in self.single_block_wrappers:\n            hidden_states = wrapper(\n                hidden_states=hidden_states,\n                text_embeddings=text_embeddings,\n                rotary_embeddings=rotary_embeddings,\n            )\n\n        # Extract image portion\n        hidden_states = hidden_states[:, text_seq_len:, ...]\n\n        # For edit mode: extract only the generated portion (exclude conditioning latents)\n        if conditioning_latents is not None:\n            hidden_states = hidden_states[:, :original_latent_tokens, ...]\n\n        return self.adapter.final_projection(hidden_states, text_embeddings)\n\n    def _diffusion_step(\n        self,\n        t: int,\n        config: Config,\n        latents: mx.array,\n        prompt_data: PromptData,\n        num_sync_steps: int,\n    ) -> mx.array:\n        if self.group is None:\n            return self._single_node_step(t, config, latents, prompt_data)\n        elif (\n            self.pipeline_world_size == 1 or t < config.init_time_step + num_sync_steps\n        ):\n            with trace(name=f\"sync {t}\", rank=self.rank, category=\"sync\"):\n                return self._sync_pipeline_step(\n                    t,\n                    config,\n                    latents,\n                    prompt_data,\n                )\n        else:\n            with trace(name=f\"async {t}\", rank=self.rank, category=\"async\"):\n                return self._async_pipeline_step(\n                    t,\n                    config,\n                    latents,\n                    prompt_data,\n                    is_first_async_step=t == config.init_time_step + num_sync_steps,\n                )\n\n    def _single_node_step(\n        self,\n        t: int,\n        config: Config,\n        latents: mx.array,\n        prompt_data: PromptData,\n    ) -> mx.array:\n        cond_image_grid = prompt_data.cond_image_grid\n        kontext_image_ids = prompt_data.kontext_image_ids\n        results: list[tuple[bool, mx.array]] = []\n\n        for branch in self._get_cfg_branches(prompt_data):\n            # Reset caches before each branch to ensure no state contamination\n            self._reset_all_caches()\n\n            pooled_embeds = (\n                branch.pooled if branch.pooled is not None else branch.embeds\n            )\n\n            noise = self._forward_pass(\n                latents,\n                branch.embeds,\n                pooled_embeds,\n                t=t,\n                config=config,\n                encoder_hidden_states_mask=branch.mask,\n                cond_image_grid=cond_image_grid,\n                conditioning_latents=branch.cond_latents,\n                kontext_image_ids=kontext_image_ids,\n            )\n            results.append((branch.positive, noise))\n\n        noise = self._combine_cfg_results(results)\n        return config.scheduler.step(noise=noise, timestep=t, latents=latents)  # pyright: ignore[reportAny]\n\n    def _create_patches(\n        self,\n        latents: mx.array,\n        config: Config,\n    ) -> tuple[list[mx.array], list[tuple[int, int]]]:\n        latent_height = config.height // 16\n        latent_width = config.width // 16\n\n        patch_heights, _ = calculate_patch_heights(latent_height, self.num_patches)\n        token_indices = calculate_token_indices(patch_heights, latent_width)\n\n        patch_latents = [latents[:, start:end, :] 
for start, end in token_indices]\n\n        return patch_latents, token_indices\n\n    def _run_sync_pass(\n        self,\n        t: int,\n        config: Config,\n        scaled_hidden_states: mx.array,\n        prompt_embeds: mx.array,\n        pooled_prompt_embeds: mx.array,\n        encoder_hidden_states_mask: mx.array | None,\n        cond_image_grid: tuple[int, int, int] | list[tuple[int, int, int]] | None,\n        kontext_image_ids: mx.array | None,\n        num_img_tokens: int,\n        original_latent_tokens: int,\n        conditioning_latents: mx.array | None,\n    ) -> mx.array | None:\n        \"\"\"Run one synchronous pipeline pass over the full latent sequence.\n\n        Each stage receives activations from the previous stage, runs its own\n        blocks, and sends to the next stage. Returns the noise prediction on\n        the last stage and None elsewhere.\n        \"\"\"\n        hidden_states = scaled_hidden_states\n        batch_size = hidden_states.shape[0]\n        text_seq_len = prompt_embeds.shape[1]\n        hidden_dim = self.adapter.hidden_dim\n        dtype = scaled_hidden_states.dtype\n\n        self._set_text_seq_len(text_seq_len)\n\n        if self.joint_block_wrappers:\n            for wrapper in self.joint_block_wrappers:\n                wrapper.set_encoder_mask(encoder_hidden_states_mask)\n\n        encoder_hidden_states: mx.array | None = None\n        if self.is_first_stage:\n            hidden_states, encoder_hidden_states = self.adapter.compute_embeddings(\n                hidden_states, prompt_embeds\n            )\n\n        text_embeddings = self.adapter.compute_text_embeddings(\n            t, config, pooled_prompt_embeds, hidden_states=hidden_states\n        )\n        image_rotary_embeddings = self.adapter.compute_rotary_embeddings(\n            prompt_embeds,\n            config,\n            encoder_hidden_states_mask=encoder_hidden_states_mask,\n            cond_image_grid=cond_image_grid,\n            kontext_image_ids=kontext_image_ids,\n        )\n\n        if self.has_joint_blocks:\n            if not self.is_first_stage:\n                assert self.prev_pipeline_rank is not None\n                with trace(\n                    name=f\"recv {self.prev_pipeline_rank}\",\n                    rank=self.rank,\n                    category=\"comms\",\n                ):\n                    hidden_states = mx.distributed.recv(\n                        (batch_size, num_img_tokens, hidden_dim),\n                        dtype,\n                        self.prev_pipeline_rank,\n                        group=self.group,\n                    )\n                    encoder_hidden_states = mx.distributed.recv(\n                        (batch_size, text_seq_len, hidden_dim),\n                        dtype,\n                        self.prev_pipeline_rank,\n                        group=self.group,\n                    )\n                    mx.eval(hidden_states, encoder_hidden_states)\n\n            assert self.joint_block_wrappers is not None\n            assert encoder_hidden_states is not None\n            with trace(\n                name=\"joint_blocks\",\n                rank=self.rank,\n                category=\"compute\",\n            ):\n                for wrapper in self.joint_block_wrappers:\n                    wrapper.set_patch(BlockWrapperMode.CACHING)\n                    encoder_hidden_states, hidden_states = wrapper(\n                        hidden_states=hidden_states,\n                        encoder_hidden_states=encoder_hidden_states,\n                        text_embeddings=text_embeddings,\n                        rotary_embeddings=image_rotary_embeddings,\n                    )\n\n                if EXO_TRACING_ENABLED:\n                    mx.eval(encoder_hidden_states, hidden_states)\n
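\n        # The stage that owns the joint/single block boundary merges the two\n        # streams into one sequence before handing off downstream.\n        if 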
self.owns_concat_stage:\n            assert encoder_hidden_states is not None\n            concatenated = self.adapter.merge_streams(\n                hidden_states, encoder_hidden_states\n            )\n\n            if self.has_single_blocks or self.is_last_stage:\n                hidden_states = concatenated\n            else:\n                assert self.next_pipeline_rank is not None\n                with trace(\n                    name=f\"send {self.next_pipeline_rank}\",\n                    rank=self.rank,\n                    category=\"comms\",\n                ):\n                    concatenated = mx.distributed.send(\n                        concatenated, self.next_pipeline_rank, group=self.group\n                    )\n                    mx.async_eval(concatenated)\n\n        elif self.has_joint_blocks and not self.is_last_stage:\n            assert encoder_hidden_states is not None\n            assert self.next_pipeline_rank is not None\n            with trace(\n                name=f\"send {self.next_pipeline_rank}\",\n                rank=self.rank,\n                category=\"comms\",\n            ):\n                hidden_states = mx.distributed.send(\n                    hidden_states, self.next_pipeline_rank, group=self.group\n                )\n                encoder_hidden_states = mx.distributed.send(\n                    encoder_hidden_states, self.next_pipeline_rank, group=self.group\n                )\n                mx.async_eval(hidden_states, encoder_hidden_states)\n\n        if self.has_single_blocks:\n            if not self.owns_concat_stage and not self.is_first_stage:\n                assert self.prev_pipeline_rank is not None\n                with trace(\n                    name=f\"recv {self.prev_pipeline_rank}\",\n                    rank=self.rank,\n                    category=\"comms\",\n                ):\n                    hidden_states = mx.distributed.recv(\n                        (batch_size, text_seq_len + num_img_tokens, hidden_dim),\n                        dtype,\n                        self.prev_pipeline_rank,\n                        group=self.group,\n                    )\n                    mx.eval(hidden_states)\n\n            assert self.single_block_wrappers is not None\n            with trace(\n                name=\"single blocks\",\n                rank=self.rank,\n                category=\"compute\",\n            ):\n                for wrapper in self.single_block_wrappers:\n                    wrapper.set_patch(BlockWrapperMode.CACHING)\n                    hidden_states = wrapper(\n                        hidden_states=hidden_states,\n                        text_embeddings=text_embeddings,\n                        rotary_embeddings=image_rotary_embeddings,\n                    )\n\n                if EXO_TRACING_ENABLED:\n                    mx.eval(hidden_states)\n\n            if not self.is_last_stage:\n                assert self.next_pipeline_rank is not None\n                with trace(\n                    name=f\"send {self.next_pipeline_rank}\",\n                    rank=self.rank,\n                    category=\"comms\",\n                ):\n                    hidden_states = mx.distributed.send(\n                        hidden_states, self.next_pipeline_rank, group=self.group\n                    )\n                    mx.async_eval(hidden_states)\n\n        hidden_states = hidden_states[:, text_seq_len:, ...]\n\n        if conditioning_latents is not None:\n            hidden_states = 
hidden_states[:, :original_latent_tokens, ...]\n\n        if self.is_last_stage:\n            return self.adapter.final_projection(hidden_states, text_embeddings)\n\n        return None\n\n    def _sync_pipeline_step(\n        self,\n        t: int,\n        config: Config,\n        hidden_states: mx.array,\n        prompt_data: PromptData,\n    ) -> mx.array:\n        prev_latents = hidden_states\n        cond_image_grid = prompt_data.cond_image_grid\n        kontext_image_ids = prompt_data.kontext_image_ids\n\n        scaled_hidden_states = config.scheduler.scale_model_input(hidden_states, t)  # pyright: ignore[reportAny]\n        original_latent_tokens: int = scaled_hidden_states.shape[1]  # pyright: ignore[reportAny]\n\n        results: list[tuple[bool, mx.array]] = []\n\n        for branch in self._get_cfg_branches(prompt_data):\n            pooled_embeds = (\n                branch.pooled if branch.pooled is not None else branch.embeds\n            )\n\n            cond_latents = branch.cond_latents\n            if cond_latents is not None:\n                num_img_tokens: int = original_latent_tokens + cond_latents.shape[1]\n            else:\n                num_img_tokens = original_latent_tokens\n\n            step_latents: mx.array = scaled_hidden_states  # pyright: ignore[reportAny]\n            if self.is_first_stage and cond_latents is not None:\n                step_latents = mx.concatenate([step_latents, cond_latents], axis=1)\n\n            text_seq_len = branch.embeds.shape[1]\n            self._ensure_wrappers(text_seq_len, branch.mask)\n\n            noise = self._run_sync_pass(\n                t,\n                config,\n                step_latents,\n                branch.embeds,\n                pooled_embeds,\n                branch.mask,\n                cond_image_grid,\n                kontext_image_ids,\n                num_img_tokens,\n                original_latent_tokens,\n                cond_latents,\n            )\n\n            if self.is_last_stage:\n                assert noise is not None\n                results.append((branch.positive, noise))\n\n        if self.is_last_stage:\n            noise = self._combine_cfg_results(results)\n\n            hidden_states = config.scheduler.step(  # pyright: ignore[reportAny]\n                noise=noise, timestep=t, latents=prev_latents\n            )\n\n            if not self.is_first_stage:\n                hidden_states = mx.distributed.send(\n                    hidden_states, self.first_pipeline_rank, group=self.group\n                )\n                mx.async_eval(hidden_states)\n\n        elif self.is_first_stage:\n            hidden_states = mx.distributed.recv_like(\n                prev_latents, src=self.last_pipeline_rank, group=self.group\n            )\n            mx.eval(hidden_states)\n\n        else:\n            hidden_states = prev_latents\n\n        return hidden_states\n\n    def _async_pipeline_step(\n        self,\n        t: int,\n        config: Config,\n        latents: mx.array,\n        prompt_data: PromptData,\n        is_first_async_step: bool,\n    ) -> mx.array:\n        patch_latents, token_indices = self._create_patches(latents, config)\n        cond_image_grid = prompt_data.cond_image_grid\n        kontext_image_ids = prompt_data.kontext_image_ids\n\n        prev_patch_latents = [p for p in patch_latents]\n\n        encoder_hidden_states: mx.array | None = None\n\n        for patch_idx in range(len(patch_latents)):\n            patch = patch_latents[patch_idx]\n\n       
     if (\n                self.is_first_stage\n                and not self.is_last_stage\n                and not is_first_async_step\n            ):\n                with trace(\n                    name=f\"recv {self.last_pipeline_rank}\",\n                    rank=self.rank,\n                    category=\"comms\",\n                ):\n                    patch = mx.distributed.recv_like(\n                        patch, src=self.last_pipeline_rank, group=self.group\n                    )\n                    mx.eval(patch)\n\n            results: list[tuple[bool, mx.array]] = []\n\n            for branch in self._get_cfg_branches(prompt_data):\n                pooled_embeds = (\n                    branch.pooled if branch.pooled is not None else branch.embeds\n                )\n\n                text_seq_len = branch.embeds.shape[1]\n                self._ensure_wrappers(text_seq_len, branch.mask)\n                self._set_text_seq_len(text_seq_len)\n\n                if self.joint_block_wrappers:\n                    for wrapper in self.joint_block_wrappers:\n                        wrapper.set_encoder_mask(branch.mask)\n\n                text_embeddings = self.adapter.compute_text_embeddings(\n                    t, config, pooled_embeds\n                )\n                image_rotary_embeddings = self.adapter.compute_rotary_embeddings(\n                    branch.embeds,\n                    config,\n                    encoder_hidden_states_mask=branch.mask,\n                    cond_image_grid=cond_image_grid,\n                    kontext_image_ids=kontext_image_ids,\n                )\n\n                noise, encoder_hidden_states = self._run_single_patch_pass(\n                    patch=patch,\n                    patch_idx=patch_idx,\n                    token_indices=token_indices[patch_idx],\n                    prompt_embeds=branch.embeds,\n                    text_embeddings=text_embeddings,\n                    image_rotary_embeddings=image_rotary_embeddings,\n                    encoder_hidden_states=encoder_hidden_states,\n                )\n\n                if self.is_last_stage:\n                    assert noise is not None\n                    results.append((branch.positive, noise))\n\n            if self.is_last_stage:\n                noise = self._combine_cfg_results(results)\n\n                patch_latents[patch_idx] = config.scheduler.step(  # pyright: ignore[reportAny]\n                    noise=noise,\n                    timestep=t,\n                    latents=prev_patch_latents[patch_idx],\n                )\n\n                if not self.is_first_stage and t != config.num_inference_steps - 1:\n                    with trace(\n                        name=f\"send {self.first_pipeline_rank}\",\n                        rank=self.rank,\n                        category=\"comms\",\n                    ):\n                        patch_latents[patch_idx] = mx.distributed.send(\n                            patch_latents[patch_idx],\n                            self.first_pipeline_rank,\n                            group=self.group,\n                        )\n                        mx.async_eval(patch_latents[patch_idx])\n\n        return mx.concatenate(patch_latents, axis=1)\n\n    def _run_single_patch_pass(\n        self,\n        patch: mx.array,\n        patch_idx: int,\n        token_indices: tuple[int, int],\n        prompt_embeds: mx.array,\n        text_embeddings: mx.array,\n        image_rotary_embeddings: RotaryEmbeddings,\n        
encoder_hidden_states: mx.array | None,\n    ) -> tuple[mx.array | None, mx.array | None]:\n        \"\"\"Process a single patch through the forward pipeline.\n\n        Handles stage-to-stage communication (stage i -> stage i+1).\n        Ring communication (last stage -> first stage) is handled by the caller.\n\n        Args:\n            patch: The patch latents to process\n            patch_idx: Index of this patch (0-indexed)\n            token_indices: (start_token, end_token) for this patch\n            prompt_embeds: Text embeddings (for compute_embeddings on first stage)\n            text_embeddings: Precomputed text embeddings\n            image_rotary_embeddings: Precomputed rotary embeddings\n            encoder_hidden_states: Encoder hidden states (passed between patches)\n\n        Returns:\n            (noise_prediction, encoder_hidden_states) - noise is None for non-last stages\n        \"\"\"\n        start_token, end_token = token_indices\n        batch_size = patch.shape[0]\n        text_seq_len = prompt_embeds.shape[1]\n        hidden_dim = self.adapter.hidden_dim\n\n        if self.has_joint_blocks:\n            if not self.is_first_stage:\n                assert self.prev_pipeline_rank is not None\n                patch_len = patch.shape[1]\n                with trace(\n                    name=f\"recv {self.prev_pipeline_rank}\",\n                    rank=self.rank,\n                    category=\"comms\",\n                ):\n                    patch = mx.distributed.recv(\n                        (batch_size, patch_len, hidden_dim),\n                        patch.dtype,\n                        self.prev_pipeline_rank,\n                        group=self.group,\n                    )\n                    mx.eval(patch)\n\n                if patch_idx == 0:\n                    with trace(\n                        name=f\"recv {self.prev_pipeline_rank}\",\n                        rank=self.rank,\n                        category=\"comms\",\n                    ):\n                        encoder_hidden_states = mx.distributed.recv(\n                            (batch_size, text_seq_len, hidden_dim),\n                            patch.dtype,\n                            self.prev_pipeline_rank,\n                            group=self.group,\n                        )\n                        mx.eval(encoder_hidden_states)\n\n            if self.is_first_stage:\n                patch, encoder_hidden_states = self.adapter.compute_embeddings(\n                    patch, prompt_embeds\n                )\n\n            assert self.joint_block_wrappers is not None\n            assert encoder_hidden_states is not None\n            with trace(\n                name=f\"joint patch {patch_idx}\",\n                rank=self.rank,\n                category=\"compute\",\n            ):\n                for wrapper in self.joint_block_wrappers:\n                    wrapper.set_patch(BlockWrapperMode.PATCHED, start_token, end_token)\n                    encoder_hidden_states, patch = wrapper(\n                        hidden_states=patch,\n                        encoder_hidden_states=encoder_hidden_states,\n                        text_embeddings=text_embeddings,\n                        rotary_embeddings=image_rotary_embeddings,\n                    )\n\n                if EXO_TRACING_ENABLED:\n                    mx.eval(encoder_hidden_states, patch)\n\n        if self.owns_concat_stage:\n            assert encoder_hidden_states is not None\n            patch_concat = 
self.adapter.merge_streams(patch, encoder_hidden_states)\n\n            if self.has_single_blocks or self.is_last_stage:\n                patch = patch_concat\n            else:\n                assert self.next_pipeline_rank is not None\n                with trace(\n                    name=f\"send {self.next_pipeline_rank}\",\n                    rank=self.rank,\n                    category=\"comms\",\n                ):\n                    patch_concat = mx.distributed.send(\n                        patch_concat, self.next_pipeline_rank, group=self.group\n                    )\n                    mx.async_eval(patch_concat)\n\n        elif self.has_joint_blocks and not self.is_last_stage:\n            assert self.next_pipeline_rank is not None\n            with trace(\n                name=f\"send {self.next_pipeline_rank}\",\n                rank=self.rank,\n                category=\"comms\",\n            ):\n                patch = mx.distributed.send(\n                    patch, self.next_pipeline_rank, group=self.group\n                )\n                mx.async_eval(patch)\n\n            if patch_idx == 0:\n                assert encoder_hidden_states is not None\n                with trace(\n                    name=f\"send {self.next_pipeline_rank}\",\n                    rank=self.rank,\n                    category=\"comms\",\n                ):\n                    encoder_hidden_states = mx.distributed.send(\n                        encoder_hidden_states, self.next_pipeline_rank, group=self.group\n                    )\n                    mx.async_eval(encoder_hidden_states)\n\n        if self.has_single_blocks:\n            if not self.owns_concat_stage and not self.is_first_stage:\n                assert self.prev_pipeline_rank is not None\n                patch_len = patch.shape[1]\n                with trace(\n                    name=f\"recv {self.prev_pipeline_rank}\",\n                    rank=self.rank,\n                    category=\"comms\",\n                ):\n                    patch = mx.distributed.recv(\n                        (batch_size, text_seq_len + patch_len, hidden_dim),\n                        patch.dtype,\n                        self.prev_pipeline_rank,\n                        group=self.group,\n                    )\n                    mx.eval(patch)\n\n            assert self.single_block_wrappers is not None\n            with trace(\n                name=f\"single patch {patch_idx}\",\n                rank=self.rank,\n                category=\"compute\",\n            ):\n                for wrapper in self.single_block_wrappers:\n                    wrapper.set_patch(BlockWrapperMode.PATCHED, start_token, end_token)\n                    patch = wrapper(\n                        hidden_states=patch,\n                        text_embeddings=text_embeddings,\n                        rotary_embeddings=image_rotary_embeddings,\n                    )\n\n                if EXO_TRACING_ENABLED:\n                    mx.eval(patch)\n\n            if not self.is_last_stage:\n                assert self.next_pipeline_rank is not None\n                with trace(\n                    name=f\"send {self.next_pipeline_rank}\",\n                    rank=self.rank,\n                    category=\"comms\",\n                ):\n                    patch = mx.distributed.send(\n                        patch, self.next_pipeline_rank, group=self.group\n                    )\n                    mx.async_eval(patch)\n\n        noise: mx.array | None = None\n 
       if self.is_last_stage:\n            patch_img_only = patch[:, text_seq_len:, :]\n            noise = self.adapter.final_projection(patch_img_only, text_embeddings)\n\n        return noise, encoder_hidden_states\n"
  },
  {
    "path": "src/exo/worker/engines/mlx/__init__.py",
    "content": ""
  },
  {
    "path": "src/exo/worker/engines/mlx/auto_parallel.py",
    "content": "import os\nimport threading\nfrom abc import ABC, abstractmethod\nfrom collections.abc import Callable\nfrom functools import partial\nfrom inspect import signature\nfrom typing import TYPE_CHECKING, Any, Literal, Protocol, cast\n\nimport mlx.core as mx\nimport mlx.nn as nn\nfrom mlx.nn.layers.distributed import (\n    shard_inplace,\n    shard_linear,\n    sum_gradients,\n)\nfrom mlx_lm.models.base import (\n    scaled_dot_product_attention,  # pyright: ignore[reportUnknownVariableType]\n)\nfrom mlx_lm.models.cache import ArraysCache, KVCache\nfrom mlx_lm.models.deepseek_v3 import DeepseekV3MLP\nfrom mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model\nfrom mlx_lm.models.deepseek_v32 import DeepseekV32MLP\nfrom mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model\nfrom mlx_lm.models.glm4_moe import Model as Glm4MoeModel\nfrom mlx_lm.models.glm4_moe import MoE\nfrom mlx_lm.models.glm4_moe_lite import Glm4MoeLiteDecoderLayer, Glm4MoeLiteMLP\nfrom mlx_lm.models.glm4_moe_lite import Model as GLM4MoeLiteModel\nfrom mlx_lm.models.gpt_oss import GptOssMoeModel\nfrom mlx_lm.models.gpt_oss import Model as GptOssModel\nfrom mlx_lm.models.kimi_k25 import Model as KimiK25Model\nfrom mlx_lm.models.llama import Model as LlamaModel\nfrom mlx_lm.models.minimax import MiniMaxAttention\nfrom mlx_lm.models.minimax import Model as MiniMaxModel\nfrom mlx_lm.models.ministral3 import Model as Ministral3Model\nfrom mlx_lm.models.nemotron_h import Model as NemotronHModel\nfrom mlx_lm.models.nemotron_h import (\n    NemotronHAttention,\n    NemotronHMamba2Mixer,\n    NemotronHMoE,\n)\nfrom mlx_lm.models.nemotron_h import NemotronHModel as NemotronHInnerModel\nfrom mlx_lm.models.qwen3_5 import DecoderLayer as Qwen3_5DecoderLayer\nfrom mlx_lm.models.qwen3_5 import Model as Qwen3_5TextModel\nfrom mlx_lm.models.qwen3_5 import Qwen3_5TextModel as Qwen3_5TextModelInner\nfrom mlx_lm.models.qwen3_5 import SparseMoeBlock as Qwen3_5SparseMoeBlock\nfrom mlx_lm.models.qwen3_5_moe import Model as Qwen3_5MoeModel\nfrom mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel\nfrom mlx_lm.models.qwen3_moe import Qwen3MoeDecoderLayer, Qwen3MoeSparseMoeBlock\nfrom mlx_lm.models.qwen3_next import Model as Qwen3NextModel\nfrom mlx_lm.models.qwen3_next import (\n    Qwen3NextDecoderLayer,\n    Qwen3NextGatedDeltaNet,\n    Qwen3NextSparseMoeBlock,\n)\nfrom mlx_lm.models.qwen3_next import Qwen3NextModel as Qwen3NextInnerModel\nfrom mlx_lm.models.step3p5 import Model as Step35Model\nfrom mlx_lm.models.step3p5 import Step3p5MLP as Step35MLP\nfrom mlx_lm.models.step3p5 import Step3p5Model as Step35InnerModel\n\nfrom exo.shared.logging import logger\nfrom exo.shared.types.worker.shards import PipelineShardMetadata\n\nif TYPE_CHECKING:\n    from mlx_lm.models.cache import Cache\n\nTimeoutCallback = Callable[[], None]\nLayerLoadedCallback = Callable[[int, int], None]  # (layers_loaded, total_layers)\n\n\n_pending_prefill_sends: list[tuple[mx.array, int, mx.distributed.Group]] = []\n\n\ndef flush_prefill_sends() -> None:\n    for output, dst, group in _pending_prefill_sends:\n        sent = mx.distributed.send(output, dst, group=group)\n        mx.async_eval(sent)\n    _pending_prefill_sends.clear()\n\n\ndef clear_prefill_sends() -> None:\n    # Discard pending sends (e.g. 
on cancellation).\n    _pending_prefill_sends.clear()\n\n\ndef eval_with_timeout(\n    mlx_item: Any,  # pyright: ignore[reportAny]\n    timeout_seconds: float = 60.0,\n    on_timeout: TimeoutCallback | None = None,\n) -> None:\n    \"\"\"Evaluate MLX item with a hard timeout.\n\n    If on_timeout callback is provided, it will be called before terminating\n    the process. This allows the runner to send a failure event before exit.\n    \"\"\"\n    completed = threading.Event()\n\n    def watchdog() -> None:\n        if not completed.wait(timeout=timeout_seconds):\n            logger.error(\n                f\"mlx_item evaluation timed out after {timeout_seconds:.0f}s. \"\n                \"This may indicate an issue with FAST_SYNCH and tensor parallel sharding. \"\n                \"Terminating process.\"\n            )\n            if on_timeout is not None:\n                on_timeout()\n            os._exit(1)\n\n    watchdog_thread = threading.Thread(target=watchdog, daemon=True)\n    watchdog_thread.start()\n\n    try:\n        mx.eval(mlx_item)  # pyright: ignore[reportAny]\n    finally:\n        completed.set()\n\n\nclass _LayerCallable(Protocol):\n    \"\"\"Structural type that any compatible layer must satisfy.\n\n    We require a single positional input of type ``mx.array`` and an\n    ``mx.array`` output, while permitting arbitrary *args / **kwargs so this\n    protocol matches the vast majority of `mlx.nn.Module` subclasses.\n    \"\"\"\n\n    def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array: ...\n\n\nclass CustomMlxLayer(nn.Module):\n    \"\"\"Base class for replacing an MLX layer with a custom implementation.\"\"\"\n\n    def __init__(self, original_layer: _LayerCallable):\n        super().__init__()\n        dict.__setitem__(self, \"_original_layer\", original_layer)  # pyright: ignore[reportUnknownMemberType]\n\n    @property\n    def original_layer(self) -> _LayerCallable:\n        return cast(_LayerCallable, self[\"_original_layer\"])\n\n    # Calls __getattr__ for any attributes not found on nn.Module (e.g. 
use_sliding)\n    if not TYPE_CHECKING:\n\n        def __getattr__(self, name):\n            try:\n                return super().__getattr__(name)\n            except AttributeError:\n                original_layer = cast(_LayerCallable, self[\"_original_layer\"])\n                return getattr(original_layer, name)\n\n\nclass PipelineFirstLayer(CustomMlxLayer):\n    def __init__(\n        self,\n        original_layer: _LayerCallable,\n        r: int,\n        group: mx.distributed.Group,\n    ):\n        super().__init__(original_layer)\n        self.r: int = r\n        self.group = group\n        self.is_prefill: bool = False\n\n    def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:\n        if self.r != 0:\n            # We want to avoid GPU timeout errors by evalling the distributed operation\n            # so that it stays on CPU, which does not have a timeout.\n            mx.eval(x)\n            x = mx.distributed.recv_like(x, (self.r - 1), group=self.group)\n            mx.eval(x)\n        return self.original_layer(x, *args, **kwargs)\n\n\nclass PipelineLastLayer(CustomMlxLayer):\n    def __init__(\n        self,\n        original_layer: _LayerCallable,\n        r: int,\n        s: int,\n        group: mx.distributed.Group,\n    ):\n        super().__init__(original_layer)\n        self.r: int = r\n        self.s: int = s\n        self.group = group\n        self.original_layer_signature = signature(self.original_layer.__call__)\n        self.is_prefill: bool = False\n        self.queue_sends: bool = False\n\n    def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:\n        cache = self.original_layer_signature.bind_partial(\n            x, *args, **kwargs\n        ).arguments.get(\"cache\", None)\n\n        output: mx.array = self.original_layer(x, *args, **kwargs)\n\n        # Eval layer output to materialize it before send — this splits the graph\n        # so the send is isolated and the receiving rank's recv can complete.\n        mx.eval(output)\n\n        if self.r != self.s - 1:\n            if self.queue_sends:\n                _pending_prefill_sends.append(\n                    (output, (self.r + 1) % self.s, self.group)\n                )\n            else:\n                output = mx.distributed.send(\n                    output, (self.r + 1) % self.s, group=self.group\n                )\n            if cache is not None:\n                # CacheList (used by MLA models like DeepSeekV32, GLM MoE DSA)\n                # doesn't have .keys directly; access via first sub-cache.\n                _cache = cache[0] if hasattr(cache, \"caches\") else cache  # type: ignore\n                if hasattr(_cache, \"keys\"):  # pyright: ignore[reportAny]\n                    _cache.keys = mx.depends(_cache.keys, output)  # type: ignore\n            mx.eval(output)\n            if cache is not None and hasattr(_cache, \"keys\"):  # type: ignore\n                mx.eval(_cache.keys)  # type: ignore\n\n        if not self.is_prefill:\n            output = mx.distributed.all_gather(output, group=self.group)[\n                -output.shape[0] :\n            ]\n            mx.eval(output)\n\n        return output\n\n\ndef set_pipeline_prefill(model: nn.Module, is_prefill: bool) -> None:\n    for layer in model.layers:  # type: ignore\n        if isinstance(layer, (PipelineFirstLayer, PipelineLastLayer)):\n            layer.is_prefill = is_prefill\n\n\ndef set_pipeline_queue_sends(model: nn.Module, queue_sends: bool) -> None:\n    for 
layer in model.layers:  # type: ignore\n        if isinstance(layer, PipelineLastLayer):\n            layer.queue_sends = queue_sends\n\n\ndef get_inner_model(model: nn.Module) -> nn.Module:\n    inner = getattr(model, \"model\", None)\n    if isinstance(inner, nn.Module):\n        return inner\n\n    inner = getattr(model, \"transformer\", None)\n    if isinstance(inner, nn.Module):\n        return inner\n\n    inner = getattr(model, \"language_model\", None)\n    if isinstance(inner, nn.Module):\n        inner_inner = getattr(inner, \"model\", None)\n        if isinstance(inner_inner, nn.Module):\n            return inner_inner\n\n    inner = getattr(model, \"backbone\", None)\n    if isinstance(inner, nn.Module):\n        return inner\n\n    raise ValueError(\n        \"Model must have a 'model', 'transformer', 'language_model', or 'backbone' attribute\"\n    )\n\n\ndef get_layers(inner_model_instance: nn.Module) -> list[_LayerCallable]:\n    # Handle both model.layers and model.h cases\n    layers: list[_LayerCallable]\n    if hasattr(inner_model_instance, \"layers\"):\n        layers = cast(list[_LayerCallable], inner_model_instance.layers)\n    elif hasattr(inner_model_instance, \"h\"):\n        layers = cast(list[_LayerCallable], inner_model_instance.h)\n    else:\n        raise ValueError(\"Model must have either a 'layers' or 'h' attribute\")\n\n    return layers\n\n\ndef _patch_hybrid_cache(\n    model: Qwen3_5TextModel | Qwen3NextModel | NemotronHModel,\n    fa_idx: int,\n    has_full_attn: bool,\n    ssm_idx: int,\n    has_linear: bool,\n) -> None:\n    # Hacks to make make_mask happy.\n    original = model.make_cache\n\n    def patched() -> list[ArraysCache | KVCache]:\n        cache = original()\n        if not has_full_attn:\n            entry = cache[fa_idx]\n            orig_make_mask = entry.make_mask\n            entry.make_mask = lambda n, **_kw: orig_make_mask(n)  # type: ignore\n        if not has_linear:\n            orig_ssm_make_mask = cache[ssm_idx].make_mask\n\n            def _ssm_mask(\n                n: int, **kw: bool | int | None\n            ) -> mx.array | Literal[\"causal\"] | None:\n                return orig_ssm_make_mask(n, **kw) if kw else None\n\n            cache[ssm_idx].make_mask = _ssm_mask  # type: ignore\n        return cache\n\n    model.make_cache = patched\n\n\ndef pipeline_auto_parallel(\n    model: nn.Module,\n    group: mx.distributed.Group,\n    model_shard_meta: PipelineShardMetadata,\n    on_layer_loaded: LayerLoadedCallback | None,\n) -> nn.Module:\n    \"\"\"Automatically parallelize a model across multiple devices.\n\n    Args:\n        model: The model to parallelize (must have a 'layers' or 'h' attribute)\n        group: The distributed group used for inter-stage communication\n        model_shard_meta: The metadata for this rank's model shard\n        on_layer_loaded: Optional callback invoked as each layer is loaded\n\n    Returns:\n        The parallelized model\n    \"\"\"\n    inner_model_instance: nn.Module = get_inner_model(model)\n\n    layers = get_layers(inner_model_instance)\n\n    start_layer, end_layer = model_shard_meta.start_layer, model_shard_meta.end_layer\n    device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size\n\n    layers = layers[start_layer:end_layer]\n    total = len(layers)\n    for i, layer in enumerate(layers):\n        mx.eval(layer)  # type: ignore\n        if on_layer_loaded is not None:\n            on_layer_loaded(i, total)\n\n    layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group)\n    layers[-1] = PipelineLastLayer(\n        layers[-1],\n        device_rank,\n        world_size,\n        group=group,\n    )\n
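\n    # Per-architecture fixups: hybrid-attention models cache layer-type\n    # indices computed against the full model, so recompute them relative to\n    # this shard's slice of layers.\n    if 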
isinstance(inner_model_instance, GptOssMoeModel):\n        inner_model_instance.layer_types = inner_model_instance.layer_types[  # type: ignore\n            start_layer:end_layer\n        ]\n        # We can assume the model has at least one layer thanks to placement.\n        # If a layer type doesn't exist, we can set it to 0.\n        inner_model_instance.swa_idx = (\n            0\n            if \"sliding_attention\" not in inner_model_instance.layer_types  # type: ignore\n            else inner_model_instance.layer_types.index(  # type: ignore\n                \"sliding_attention\"\n            )\n        )\n        inner_model_instance.ga_idx = (\n            0\n            if \"full_attention\" not in inner_model_instance.layer_types  # type: ignore\n            else inner_model_instance.layer_types.index(  # type: ignore\n                \"full_attention\"\n            )\n        )\n\n    if isinstance(inner_model_instance, Step35InnerModel):\n        inner_model_instance.num_layers = len(layers)\n        sliding_layers = [\n            i for i, layer in enumerate(layers) if getattr(layer, \"is_sliding\", False)\n        ]\n        full_layers = [\n            i\n            for i, layer in enumerate(layers)\n            if not getattr(layer, \"is_sliding\", True)\n        ]\n        inner_model_instance._swa_idx = 0 if not sliding_layers else sliding_layers[0]\n        inner_model_instance._full_idx = 0 if not full_layers else full_layers[0]\n\n    if isinstance(inner_model_instance, (Qwen3_5TextModelInner, Qwen3NextInnerModel)):\n        full_attn_layers = [\n            i for i, layer in enumerate(layers) if not getattr(layer, \"is_linear\", True)\n        ]\n        linear_layers = [\n            i for i, layer in enumerate(layers) if getattr(layer, \"is_linear\", False)\n        ]\n        inner_model_instance.fa_idx = full_attn_layers[0] if full_attn_layers else 0\n        inner_model_instance.ssm_idx = linear_layers[0] if linear_layers else 0\n        if not full_attn_layers or not linear_layers:\n            _patch_hybrid_cache(\n                cast(Qwen3_5TextModel | Qwen3NextModel, model),\n                fa_idx=inner_model_instance.fa_idx,\n                has_full_attn=bool(full_attn_layers),\n                ssm_idx=inner_model_instance.ssm_idx,\n                has_linear=bool(linear_layers),\n            )\n\n    if isinstance(inner_model_instance, NemotronHInnerModel):\n        # NemotronH uses block_type: \"M\" (Mamba/SSM), \"*\" (Attention), \"E\" (MoE), \"-\" (MLP)\n        # Only \"M\" and \"*\" blocks have cache entries.\n        # Recompute fa_idx and ssm_idx as cache-array indices for the shard's layers.\n        cache_idx = 0\n        fa_idx: int | None = None\n        ssm_idx: int | None = None\n        for layer in layers:\n            block_type = getattr(layer, \"block_type\", None)\n            if block_type == \"*\":\n                if fa_idx is None:\n                    fa_idx = cache_idx\n                cache_idx += 1\n            elif block_type == \"M\":\n                if ssm_idx is None:\n                    ssm_idx = cache_idx\n                cache_idx += 1\n        has_attn = fa_idx is not None\n        has_mamba = ssm_idx is not None\n        inner_model_instance.fa_idx = fa_idx if fa_idx is not None else 0\n        inner_model_instance.ssm_idx = ssm_idx if ssm_idx is not None else 0\n        if not has_attn or not has_mamba:\n            _patch_hybrid_cache(\n                cast(NemotronHModel, model),\n                
fa_idx=inner_model_instance.fa_idx,\n                has_full_attn=has_attn,\n                ssm_idx=inner_model_instance.ssm_idx,\n                has_linear=has_mamba,\n            )\n\n    _set_layers(model, layers)\n\n    assert isinstance(layers, list), (\n        \"Expected a list of layers after auto-parallel initialisation\"\n    )\n\n    return patch_pipeline_model(model, group)\n\n\ndef patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:\n    \"\"\"Patch the model's __call__ so pipeline distributed ops sync during inference.\"\"\"\n    cls = model.__class__\n    original_call = cls.__call__  # type: ignore\n    call_signature = signature(original_call)  # type: ignore\n\n    def patched_call(\n        self: T,\n        *args: object,\n        **kwargs: object,\n    ) -> mx.array:\n        logits: mx.array = original_call(self, *args, **kwargs)  # type: ignore\n        cache = call_signature.bind_partial(self, *args, **kwargs).arguments.get(\n            \"cache\", None\n        )\n\n        # Add dependency to last cache entry to ensure distributed ops are evaluated\n        if cache is not None:\n            last = cache[-1]  # type: ignore\n            dep_cache = last[0] if hasattr(last, \"caches\") else last  # type: ignore\n            if hasattr(dep_cache, \"keys\") and dep_cache.keys is not None:  # type: ignore\n                dep_cache.keys = mx.depends(dep_cache.keys, logits)  # type: ignore\n\n        return logits\n\n    cls.__call__ = patched_call\n    return model\n\n\ndef patch_tensor_model[T](model: T) -> T:\n    \"\"\"Patch model's __call__ to ensure distributed ops sync during inference.\"\"\"\n    cls = model.__class__\n    original_call = cls.__call__\n    call_signature = signature(original_call)\n\n    def patched_call(\n        self: T,\n        *args: object,\n        **kwargs: object,\n    ) -> mx.array:\n        logits: mx.array = original_call(self, *args, **kwargs)  # pyright: ignore[reportAny]\n        cache = call_signature.bind_partial(self, *args, **kwargs).arguments.get(\n            \"cache\", None\n        )\n\n        # Add dependency to last cache entry to ensure distributed ops are evaluated\n        if cache is not None and len(cache) > 0:  # pyright: ignore[reportAny]\n            last = cache[-1]  # pyright: ignore[reportAny]\n            dep_cache = last[0] if hasattr(last, \"caches\") else last  # pyright: ignore[reportAny]\n            if hasattr(dep_cache, \"keys\"):  # type: ignore\n                dep_cache.keys = mx.depends(dep_cache.keys, logits)  # pyright: ignore[reportAny,reportUnknownMemberType]\n\n        return logits\n\n    cls.__call__ = patched_call\n    return model\n\n\ndef tensor_auto_parallel(\n    model: nn.Module,\n    group: mx.distributed.Group,\n    timeout_seconds: float,\n    on_timeout: TimeoutCallback | None,\n    on_layer_loaded: LayerLoadedCallback | None,\n) -> nn.Module:\n    all_to_sharded_linear = partial(\n        shard_linear,\n        sharding=\"all-to-sharded\",\n        group=group,\n    )\n    sharded_to_all_linear = partial(\n        shard_linear,\n        sharding=\"sharded-to-all\",\n        group=group,\n    )\n\n    segments: int = 1\n\n    def _all_to_sharded(path: str, weight: mx.array):\n        if path.endswith(\"bias\"):\n            logger.info(f\"Sharding bias for {path} - all to sharded\")\n            return weight.ndim - 1, segments\n        return max(weight.ndim - 2, 0), segments\n
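\n    # \"all-to-sharded\" splits a linear layer's output features across ranks,\n    # while \"sharded-to-all\" splits its input features and sums the partial\n    # outputs across ranks; pairing them keeps activations full-size at layer\n    # boundaries. The *_in_place variants shard existing weights without\n    # replacing the module.\n    all_to_sharded_linear_in_place = partial(\n        shard_inplace,\n        sharding=_all_to_sharded,  # type: ignore\n        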
group=group,\n    )\n\n    n = group.size()\n\n    def _sharded_to_all(path: str, weight: mx.array):\n        if path.endswith(\"bias\"):\n            logger.info(f\"Sharding bias for {path} - sharded to all\")\n            weight /= n\n            return None\n        return -1, segments\n\n    sharded_to_all_linear_in_place = partial(\n        shard_inplace,\n        sharding=_sharded_to_all,  # type: ignore\n        group=group,\n    )\n\n    if isinstance(model, (LlamaModel, Ministral3Model)):\n        tensor_parallel_sharding_strategy = LlamaShardingStrategy(\n            group,\n            all_to_sharded_linear,\n            sharded_to_all_linear,\n            all_to_sharded_linear_in_place,\n            sharded_to_all_linear_in_place,\n        )\n    elif isinstance(model, (DeepseekV3Model, DeepseekV32Model, KimiK25Model)):\n        tensor_parallel_sharding_strategy = DeepSeekShardingStrategy(\n            group,\n            all_to_sharded_linear,\n            sharded_to_all_linear,\n            all_to_sharded_linear_in_place,\n            sharded_to_all_linear_in_place,\n        )\n    elif isinstance(model, MiniMaxModel):\n        tensor_parallel_sharding_strategy = MiniMaxShardingStrategy(\n            group,\n            all_to_sharded_linear,\n            sharded_to_all_linear,\n            all_to_sharded_linear_in_place,\n            sharded_to_all_linear_in_place,\n        )\n    elif isinstance(model, GLM4MoeLiteModel):\n        tensor_parallel_sharding_strategy = GLM4MoeLiteShardingStrategy(\n            group,\n            all_to_sharded_linear,\n            sharded_to_all_linear,\n            all_to_sharded_linear_in_place,\n            sharded_to_all_linear_in_place,\n        )\n    elif isinstance(model, Glm4MoeModel):\n        tensor_parallel_sharding_strategy = Glm4MoeShardingStrategy(\n            group,\n            all_to_sharded_linear,\n            sharded_to_all_linear,\n            all_to_sharded_linear_in_place,\n            sharded_to_all_linear_in_place,\n        )\n    elif isinstance(\n        model, (Qwen3MoeModel, Qwen3NextModel, Qwen3_5TextModel, Qwen3_5MoeModel)\n    ):\n        tensor_parallel_sharding_strategy = QwenShardingStrategy(\n            group,\n            all_to_sharded_linear,\n            sharded_to_all_linear,\n            all_to_sharded_linear_in_place,\n            sharded_to_all_linear_in_place,\n        )\n    elif isinstance(model, GptOssModel):\n        tensor_parallel_sharding_strategy = GptOssShardingStrategy(\n            group,\n            all_to_sharded_linear,\n            sharded_to_all_linear,\n            all_to_sharded_linear_in_place,\n            sharded_to_all_linear_in_place,\n        )\n    elif isinstance(model, Step35Model):\n        tensor_parallel_sharding_strategy = Step35ShardingStrategy(\n            group,\n            all_to_sharded_linear,\n            sharded_to_all_linear,\n            all_to_sharded_linear_in_place,\n            sharded_to_all_linear_in_place,\n        )\n    elif isinstance(model, NemotronHModel):\n        tensor_parallel_sharding_strategy = NemotronHShardingStrategy(\n            group,\n            all_to_sharded_linear,\n            sharded_to_all_linear,\n            all_to_sharded_linear_in_place,\n            sharded_to_all_linear_in_place,\n        )\n    else:\n        raise ValueError(f\"Unsupported model type: {type(model)}\")\n\n    model = tensor_parallel_sharding_strategy.shard_model(\n        model, timeout_seconds, on_timeout, on_layer_loaded\n    )\n    return 
patch_tensor_model(model)\n\n\nclass TensorParallelShardingStrategy(ABC):\n    def __init__(\n        self,\n        group: mx.distributed.Group,\n        all_to_sharded_linear: Callable[..., nn.Linear],\n        sharded_to_all_linear: Callable[..., nn.Linear],\n        all_to_sharded_linear_in_place: Callable[..., None],\n        sharded_to_all_linear_in_place: Callable[..., None],\n    ):\n        self.all_to_sharded_linear = all_to_sharded_linear\n        self.sharded_to_all_linear = sharded_to_all_linear\n        self.all_to_sharded_linear_in_place = all_to_sharded_linear_in_place\n        self.sharded_to_all_linear_in_place = sharded_to_all_linear_in_place\n        self.group = group\n        self.N = group.size()\n\n    @abstractmethod\n    def shard_model(\n        self,\n        model: nn.Module,\n        timeout_seconds: float,\n        on_timeout: TimeoutCallback | None,\n        on_layer_loaded: LayerLoadedCallback | None,\n    ) -> nn.Module: ...\n\n\nclass LlamaShardingStrategy(TensorParallelShardingStrategy):\n    def shard_model(\n        self,\n        model: nn.Module,\n        timeout_seconds: float,\n        on_timeout: TimeoutCallback | None,\n        on_layer_loaded: LayerLoadedCallback | None,\n    ) -> nn.Module:\n        model = cast(LlamaModel, model)\n        total = len(model.layers)\n        for i, layer in enumerate(model.layers):\n            # Force load weights before sharding to avoid FAST_SYNCH deadlock\n            eval_with_timeout(layer.parameters(), timeout_seconds / total, on_timeout)\n            layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)\n            layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)\n            layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)\n            layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)\n            layer.self_attn.n_heads //= self.N\n            if layer.self_attn.n_kv_heads is not None:\n                layer.self_attn.n_kv_heads //= self.N\n\n            layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)\n            layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)\n            layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)\n            mx.eval(layer)\n            if on_layer_loaded is not None:\n                on_layer_loaded(i, total)\n        return model\n\n\ndef _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:\n    inner_model_instance = get_inner_model(model)\n    if hasattr(inner_model_instance, \"layers\"):\n        inner_model_instance.layers = layers\n\n        # Update DeepSeek V3 specific parameters when layers are shrunk\n        if isinstance(\n            model, (DeepseekV3Model, DeepseekV32Model, Glm4MoeModel, KimiK25Model)\n        ) and hasattr(inner_model_instance, \"num_layers\"):\n            logger.info(\n                f\"Setting num_layers to {len(layers)} for model {model.model.__class__.__name__}\"\n            )\n            inner_model_instance.start_idx = 0\n            inner_model_instance.end_idx = len(layers)\n            inner_model_instance.num_layers = len(layers)\n        elif isinstance(model, Qwen3MoeModel):\n            logger.info(\n                f\"Setting num_hidden_layers to {len(layers)} for model {model.model.__class__.__name__}\"\n            )\n            inner_model_instance.num_hidden_layers = len(layers)\n    elif hasattr(inner_model_instance, \"h\"):\n        
inner_model_instance.h = layers\n    else:\n        raise ValueError(\"Model must have either a 'layers' or 'h' attribute\")\n\n\nclass DeepSeekShardingStrategy(TensorParallelShardingStrategy):\n    def shard_model(\n        self,\n        model: nn.Module,\n        timeout_seconds: float,\n        on_timeout: TimeoutCallback | None,\n        on_layer_loaded: LayerLoadedCallback | None,\n    ) -> nn.Module:\n        model = cast(DeepseekV3Model, model)\n        total = len(model.layers)\n\n        for i, layer in enumerate(model.layers):\n            eval_with_timeout(layer.parameters(), timeout_seconds / total, on_timeout)\n\n            # Shard the self attention\n            if layer.self_attn.q_lora_rank is None:\n                layer.self_attn.q_proj = self.all_to_sharded_linear(\n                    layer.self_attn.q_proj\n                )\n            else:\n                layer.self_attn.q_b_proj = self.all_to_sharded_linear(\n                    layer.self_attn.q_b_proj\n                )\n\n            layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)\n            layer.self_attn.num_heads //= self.N\n\n            # Logic from upstream mlx\n            num_heads = layer.self_attn.num_heads\n            sh = self.group.rank() * num_heads\n            eh = sh + num_heads\n\n            def shard_heads(w: mx.array, sh: int = sh, eh: int = eh) -> mx.array:\n                return w[sh:eh]\n\n            layer.self_attn.embed_q.apply(shard_heads)\n            layer.self_attn.unembed_out.apply(shard_heads)\n\n            # Shard the MLP\n            if isinstance(layer.mlp, (DeepseekV3MLP, DeepseekV32MLP)):\n                layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)\n                layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)\n                layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)\n\n            # Shard the MoE.\n            else:\n                if getattr(layer.mlp, \"shared_experts\", None) is not None:\n                    self.all_to_sharded_linear_in_place(\n                        layer.mlp.shared_experts.gate_proj\n                    )\n                    self.sharded_to_all_linear_in_place(\n                        layer.mlp.shared_experts.down_proj\n                    )\n                    self.all_to_sharded_linear_in_place(\n                        layer.mlp.shared_experts.up_proj\n                    )\n                self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)\n                self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)\n                self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)\n                layer.mlp = ShardedMoE(layer.mlp)  # type: ignore\n                layer.mlp.sharding_group = self.group\n\n            mx.eval(layer)\n            if on_layer_loaded is not None:\n                on_layer_loaded(i, total)\n\n        return model\n\n\nclass ShardedMoE(CustomMlxLayer):\n    \"\"\"Wraps any MoE layer with distributed sum_gradients / all_sum.\"\"\"\n\n    def __init__(self, layer: _LayerCallable):\n        super().__init__(layer)\n        self.sharding_group: mx.distributed.Group | None = None\n\n    def __call__(self, x: mx.array) -> mx.array:\n        if self.sharding_group is not None:\n            x = sum_gradients(self.sharding_group)(x)\n        y = self.original_layer.__call__(x)\n        if self.sharding_group is not None:\n            y = mx.distributed.all_sum(y, 
group=self.sharding_group)\n        return y\n\n\nclass GLM4MoeLiteShardingStrategy(TensorParallelShardingStrategy):\n    def shard_model(\n        self,\n        model: nn.Module,\n        timeout_seconds: float,\n        on_timeout: TimeoutCallback | None,\n        on_layer_loaded: LayerLoadedCallback | None,\n    ) -> nn.Module:\n        model = cast(GLM4MoeLiteModel, model)\n        total = len(model.layers)  # type: ignore\n        for i, layer in enumerate(model.layers):  # type: ignore\n            layer = cast(Glm4MoeLiteDecoderLayer, layer)\n            eval_with_timeout(\n                layer.parameters(),\n                timeout_seconds / total,\n                on_timeout,\n            )\n            if layer.self_attn.q_lora_rank is None:  # type: ignore\n                layer.self_attn.q_proj = self.all_to_sharded_linear(\n                    layer.self_attn.q_proj\n                )\n            else:\n                layer.self_attn.q_b_proj = self.all_to_sharded_linear(\n                    layer.self_attn.q_b_proj\n                )\n\n            layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)\n            layer.self_attn.num_heads //= self.N\n\n            # Logic from upstream mlx\n            num_heads = layer.self_attn.num_heads\n            sh = self.group.rank() * num_heads\n            eh = sh + num_heads\n\n            def shard_heads(w: mx.array, sh: int = sh, eh: int = eh) -> mx.array:\n                return w[sh:eh]\n\n            layer.self_attn.embed_q.apply(shard_heads)\n            layer.self_attn.unembed_out.apply(shard_heads)\n\n            if isinstance(layer.mlp, Glm4MoeLiteMLP):\n                layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)\n                layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)\n                layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)\n\n            else:\n                if getattr(layer.mlp, \"shared_experts\", None) is not None:\n                    self.all_to_sharded_linear_in_place(\n                        layer.mlp.shared_experts.gate_proj\n                    )\n                    self.sharded_to_all_linear_in_place(\n                        layer.mlp.shared_experts.down_proj\n                    )\n                    self.all_to_sharded_linear_in_place(\n                        layer.mlp.shared_experts.up_proj\n                    )\n                self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)\n                self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)\n                self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)\n                layer.mlp = ShardedMoE(layer.mlp)  # type: ignore\n                layer.mlp.sharding_group = self.group  # type: ignore\n            mx.eval(layer)\n            if on_layer_loaded is not None:\n                on_layer_loaded(i, total)\n\n        return model\n\n\nclass WrappedMiniMaxAttention(CustomMlxLayer):\n    def __init__(self, layer: _LayerCallable, group: mx.distributed.Group):\n        super().__init__(layer)\n        self.group = group\n\n    def __call__(\n        self,\n        x: mx.array,\n        mask: mx.array | None = None,\n        cache: \"Cache | None\" = None,\n    ) -> mx.array:\n        batch_dim, seq_dim, _ = x.shape\n\n        self._original_layer = cast(MiniMaxAttention, self.original_layer)  # type: ignore\n\n        queries: mx.array = self._original_layer.q_proj(x)\n        keys: mx.array = 
self._original_layer.k_proj(x)\n        values: mx.array = self._original_layer.v_proj(x)\n\n        if getattr(self._original_layer, \"use_qk_norm\", False):\n            q_dim = queries.shape[-1]\n            k_dim = keys.shape[-1]\n            n = self.group.size()\n\n            qk = mx.concatenate(\n                [queries, keys], axis=-1\n            )  # (batch_dim, seq_dim, q_dim + k_dim)\n            qk = mx.distributed.all_gather(\n                qk, group=self.group\n            )  # (n*batch_dim, seq_dim, q_dim + k_dim)\n\n            qk = qk.reshape(n, batch_dim, seq_dim, q_dim + k_dim).transpose(1, 2, 0, 3)\n            queries = qk[..., :q_dim].reshape(\n                batch_dim, seq_dim, -1\n            )  # (batch_dim, seq_dim, n * q_dim)\n            keys = qk[..., q_dim:].reshape(\n                batch_dim, seq_dim, -1\n            )  # (batch_dim, seq_dim, n * k_dim)\n\n            queries = self._original_layer.q_norm(queries)\n            keys = self._original_layer.k_norm(keys)\n\n            # Split back and take this rank's portion\n            queries = mx.split(queries, n, axis=-1)[self.group.rank()]\n            keys = mx.split(keys, n, axis=-1)[self.group.rank()]\n\n        queries = queries.reshape(\n            batch_dim, seq_dim, self._original_layer.num_attention_heads, -1\n        ).transpose(0, 2, 1, 3)\n        keys = keys.reshape(\n            batch_dim, seq_dim, self._original_layer.num_key_value_heads, -1\n        ).transpose(0, 2, 1, 3)\n        values = values.reshape(\n            batch_dim, seq_dim, self._original_layer.num_key_value_heads, -1\n        ).transpose(0, 2, 1, 3)\n\n        if cache is not None:\n            queries = self._original_layer.rope(queries, offset=cache.offset)\n            keys = self._original_layer.rope(keys, offset=cache.offset)\n            keys, values = cache.update_and_fetch(keys, values)\n        else:\n            queries = self._original_layer.rope(queries)\n            keys = self._original_layer.rope(keys)\n\n        output = scaled_dot_product_attention(\n            queries,\n            keys,\n            values,\n            cache=cache,\n            scale=self._original_layer.scale,  # type: ignore\n            mask=mask,\n        )\n\n        output = output.transpose(0, 2, 1, 3).reshape(batch_dim, seq_dim, -1)\n\n        return self._original_layer.o_proj(output)\n\n\nclass MiniMaxShardingStrategy(TensorParallelShardingStrategy):\n    def shard_model(\n        self,\n        model: nn.Module,\n        timeout_seconds: float,\n        on_timeout: TimeoutCallback | None,\n        on_layer_loaded: LayerLoadedCallback | None,\n    ) -> nn.Module:\n        model = cast(MiniMaxModel, model)\n        total = len(model.layers)\n        for i, layer in enumerate(model.layers):\n            eval_with_timeout(layer.parameters(), timeout_seconds / total, on_timeout)\n            # Shard the self attention\n            layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)\n            layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)\n            layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)\n            layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)\n\n            layer.self_attn.num_attention_heads //= self.N\n            layer.self_attn.num_key_value_heads //= self.N\n\n            layer.self_attn = WrappedMiniMaxAttention(layer.self_attn, self.group)  # pyright: ignore[reportAttributeAccessIssue,reportArgumentType]\n\n 
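           # QK-norm normalizes over the full head dimension, so WrappedMiniMaxAttention all-gathers q/k across ranks first.\n 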
           # Shard the MoE.\n            self.all_to_sharded_linear_in_place(\n                layer.block_sparse_moe.switch_mlp.gate_proj\n            )\n            self.sharded_to_all_linear_in_place(\n                layer.block_sparse_moe.switch_mlp.down_proj\n            )\n            self.all_to_sharded_linear_in_place(\n                layer.block_sparse_moe.switch_mlp.up_proj\n            )\n            layer.block_sparse_moe = ShardedMoE(layer.block_sparse_moe)  # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]\n            layer.block_sparse_moe.sharding_group = self.group  # pyright: ignore[reportAttributeAccessIssue]\n            mx.eval(layer)\n            if on_layer_loaded is not None:\n                on_layer_loaded(i, total)\n        return model\n\n\nclass QwenShardingStrategy(TensorParallelShardingStrategy):\n    def shard_model(\n        self,\n        model: nn.Module,\n        timeout_seconds: float,\n        on_timeout: TimeoutCallback | None,\n        on_layer_loaded: LayerLoadedCallback | None,\n    ) -> nn.Module:\n        model = cast(\n            Qwen3MoeModel | Qwen3NextModel | Qwen3_5TextModel | Qwen3_5MoeModel, model\n        )\n        total = len(model.layers)\n        for i, layer in enumerate(model.layers):\n            eval_with_timeout(layer.parameters(), timeout_seconds / total, on_timeout)\n            # Shard the self attention\n            if isinstance(layer, Qwen3MoeDecoderLayer):\n                layer.self_attn.q_proj = self.all_to_sharded_linear(\n                    layer.self_attn.q_proj\n                )\n                layer.self_attn.k_proj = self.all_to_sharded_linear(\n                    layer.self_attn.k_proj\n                )\n                layer.self_attn.v_proj = self.all_to_sharded_linear(\n                    layer.self_attn.v_proj\n                )\n                layer.self_attn.o_proj = self.sharded_to_all_linear(\n                    layer.self_attn.o_proj\n                )\n                layer.self_attn.n_heads //= self.N\n                layer.self_attn.n_kv_heads //= self.N\n            else:\n                assert isinstance(layer, (Qwen3NextDecoderLayer, Qwen3_5DecoderLayer))\n                if hasattr(layer, \"linear_attn\"):\n                    linear_attn = layer.linear_attn\n\n                    if isinstance(linear_attn, Qwen3NextGatedDeltaNet):\n                        # Qwen3-Next: combined projections\n                        linear_attn.in_proj_qkvz = self.all_to_sharded_linear(\n                            linear_attn.in_proj_qkvz\n                        )\n                        linear_attn.in_proj_ba = self.all_to_sharded_linear(\n                            linear_attn.in_proj_ba\n                        )\n                    else:\n                        # Qwen3.5: separate projections\n                        # in_proj_qkv has sections [q(key_dim), k(key_dim), v(value_dim)]\n                        # that must be split section-aware, not as a contiguous block\n                        key_dim = linear_attn.key_dim\n                        value_dim = linear_attn.value_dim\n                        linear_attn.in_proj_qkv = shard_linear(\n                            linear_attn.in_proj_qkv,\n                            \"all-to-sharded\",\n                            segments=[key_dim, key_dim + key_dim],\n                            group=self.group,\n                        )\n                        linear_attn.in_proj_z = self.all_to_sharded_linear(\n                   
         linear_attn.in_proj_z\n                        )\n                        linear_attn.in_proj_b = self.all_to_sharded_linear(\n                            linear_attn.in_proj_b\n                        )\n                        linear_attn.in_proj_a = self.all_to_sharded_linear(\n                            linear_attn.in_proj_a\n                        )\n                    linear_attn.out_proj = self.sharded_to_all_linear(\n                        linear_attn.out_proj\n                    )\n\n                    # Shard conv1d: depthwise conv with non-contiguous channel slicing.\n                    # Channel layout is [q(key_dim), k(key_dim), v(value_dim)].\n                    # Each rank takes its head-slice from each of the three sections.\n                    rank = self.group.rank()\n                    key_dim = linear_attn.key_dim\n                    value_dim = linear_attn.value_dim\n                    key_dim_shard = key_dim // self.N\n                    value_dim_shard = value_dim // self.N\n\n                    q_idx = mx.arange(rank * key_dim_shard, (rank + 1) * key_dim_shard)\n                    k_idx = mx.arange(\n                        key_dim + rank * key_dim_shard,\n                        key_dim + (rank + 1) * key_dim_shard,\n                    )\n                    v_idx = mx.arange(\n                        2 * key_dim + rank * value_dim_shard,\n                        2 * key_dim + (rank + 1) * value_dim_shard,\n                    )\n                    conv_indices = mx.concatenate([q_idx, k_idx, v_idx])\n                    linear_attn.conv1d.weight = linear_attn.conv1d.weight[conv_indices]\n                    new_conv_dim = key_dim_shard * 2 + value_dim_shard\n                    linear_attn.conv1d.groups = new_conv_dim\n\n                    num_v_shard = linear_attn.num_v_heads // self.N\n                    v_start = rank * num_v_shard\n                    v_end = v_start + num_v_shard\n                    linear_attn.A_log = linear_attn.A_log[v_start:v_end]\n                    linear_attn.dt_bias = linear_attn.dt_bias[v_start:v_end]\n\n                    linear_attn.num_k_heads //= self.N\n                    linear_attn.num_v_heads //= self.N\n                    linear_attn.key_dim = (\n                        linear_attn.head_k_dim * linear_attn.num_k_heads\n                    )\n                    linear_attn.value_dim = (\n                        linear_attn.head_v_dim * linear_attn.num_v_heads\n                    )\n                    linear_attn.conv_dim = (\n                        linear_attn.key_dim * 2 + linear_attn.value_dim\n                    )\n                else:\n                    layer.self_attn.q_proj = self.all_to_sharded_linear(\n                        layer.self_attn.q_proj\n                    )\n                    layer.self_attn.k_proj = self.all_to_sharded_linear(\n                        layer.self_attn.k_proj\n                    )\n                    layer.self_attn.v_proj = self.all_to_sharded_linear(\n                        layer.self_attn.v_proj\n                    )\n                    layer.self_attn.o_proj = self.sharded_to_all_linear(\n                        layer.self_attn.o_proj\n                    )\n                    layer.self_attn.num_attention_heads //= self.N\n                    layer.self_attn.num_key_value_heads //= self.N\n\n            # Shard the MoE.\n            if isinstance(\n                layer.mlp,\n                (\n                    
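# all Qwen MoE block variants route experts through switch_mlp, so they share this path\n                    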
Qwen3MoeSparseMoeBlock,\n                    Qwen3NextSparseMoeBlock,\n                    Qwen3_5SparseMoeBlock,\n                ),\n            ):\n                self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)\n                self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)\n                self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)\n                if isinstance(\n                    layer.mlp, (Qwen3NextSparseMoeBlock, Qwen3_5SparseMoeBlock)\n                ):\n                    self.all_to_sharded_linear_in_place(\n                        layer.mlp.shared_expert.gate_proj\n                    )\n                    self.sharded_to_all_linear_in_place(\n                        layer.mlp.shared_expert.down_proj\n                    )\n                    self.all_to_sharded_linear_in_place(layer.mlp.shared_expert.up_proj)\n                layer.mlp = ShardedMoE(layer.mlp)  # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]\n                layer.mlp.sharding_group = self.group\n\n            # Shard the MLP\n            else:\n                layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)\n                layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)\n                layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)\n\n            mx.eval(layer)\n            if on_layer_loaded is not None:\n                on_layer_loaded(i, total)\n        return model\n\n\nclass Glm4MoeShardingStrategy(TensorParallelShardingStrategy):\n    def shard_model(\n        self,\n        model: nn.Module,\n        timeout_seconds: float,\n        on_timeout: TimeoutCallback | None,\n        on_layer_loaded: LayerLoadedCallback | None,\n    ) -> nn.Module:\n        model = cast(Glm4MoeModel, model)\n        total = len(model.layers)\n        for i, layer in enumerate(model.layers):\n            eval_with_timeout(layer.parameters(), timeout_seconds / total, on_timeout)\n\n            layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)\n            layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)\n            layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)\n            layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)\n            layer.self_attn.n_heads //= self.N\n            layer.self_attn.n_kv_heads //= self.N\n\n            if isinstance(layer.mlp, MoE):\n                self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)\n                self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)\n                self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)\n                if getattr(layer.mlp, \"shared_experts\", None) is not None:\n                    self.all_to_sharded_linear_in_place(\n                        layer.mlp.shared_experts.gate_proj\n                    )\n                    self.sharded_to_all_linear_in_place(\n                        layer.mlp.shared_experts.down_proj\n                    )\n                    self.all_to_sharded_linear_in_place(\n                        layer.mlp.shared_experts.up_proj\n                    )\n                layer.mlp = ShardedMoE(layer.mlp)  # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]\n                layer.mlp.sharding_group = self.group\n\n            else:\n                layer.mlp.gate_proj = 
self.all_to_sharded_linear(layer.mlp.gate_proj)\n                layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)\n                layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)\n\n            mx.eval(layer)\n            if on_layer_loaded is not None:\n                on_layer_loaded(i, total)\n        return model\n\n\nclass GptOssShardingStrategy(TensorParallelShardingStrategy):\n    def shard_model(\n        self,\n        model: nn.Module,\n        timeout_seconds: float,\n        on_timeout: TimeoutCallback | None,\n        on_layer_loaded: LayerLoadedCallback | None,\n    ) -> nn.Module:\n        model = cast(GptOssMoeModel, model)\n        total = len(model.layers)\n\n        for i, layer in enumerate(model.layers):\n            eval_with_timeout(layer.parameters(), timeout_seconds / total, on_timeout)\n            layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)\n            layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)\n            layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)\n            layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)\n\n            layer.self_attn.num_attention_heads //= self.N\n            layer.self_attn.num_key_value_heads //= self.N\n            layer.self_attn.num_key_value_groups = (\n                layer.self_attn.num_attention_heads\n                // layer.self_attn.num_key_value_heads\n            )\n\n            layer.self_attn.sinks = layer.self_attn.sinks[\n                layer.self_attn.num_attention_heads\n                * self.group.rank() : layer.self_attn.num_attention_heads\n                * (self.group.rank() + 1)\n            ]\n\n            self.all_to_sharded_linear_in_place(layer.mlp.experts.gate_proj)\n            self.sharded_to_all_linear_in_place(layer.mlp.experts.down_proj)\n            self.all_to_sharded_linear_in_place(layer.mlp.experts.up_proj)\n\n            layer.mlp = ShardedMoE(layer.mlp)  # type: ignore\n            layer.mlp.sharding_group = self.group  # pyright: ignore[reportAttributeAccessIssue]\n            mx.eval(layer)\n            if on_layer_loaded is not None:\n                on_layer_loaded(i, total)\n        return model\n\n\nclass Step35ShardingStrategy(TensorParallelShardingStrategy):\n    def shard_model(\n        self,\n        model: nn.Module,\n        timeout_seconds: float,\n        on_timeout: TimeoutCallback | None,\n        on_layer_loaded: LayerLoadedCallback | None,\n    ) -> nn.Module:\n        model = cast(Step35Model, model)\n        total = len(model.layers)\n\n        for i, layer in enumerate(model.layers):\n            eval_with_timeout(layer.parameters(), timeout_seconds / total, on_timeout)\n            layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)\n            layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)\n            layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)\n            layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)\n\n            layer.self_attn.num_heads //= self.N\n            layer.self_attn.num_kv_heads //= self.N\n\n            if getattr(layer.self_attn, \"use_head_wise_attn_gate\", False):\n                layer.self_attn.g_proj = self.all_to_sharded_linear(\n                    layer.self_attn.g_proj\n                )\n\n            if isinstance(layer.mlp, Step35MLP):\n                
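# Dense MLP: gate/up are column-sharded, down is row-sharded with an all-reduce back to full width.\n                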
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)\n                layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)\n                layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)\n            else:\n                layer.mlp.sharding_group = self.group\n                self.all_to_sharded_linear_in_place(layer.mlp.share_expert.gate_proj)\n                self.all_to_sharded_linear_in_place(layer.mlp.share_expert.up_proj)\n                self.sharded_to_all_linear_in_place(layer.mlp.share_expert.down_proj)\n                self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)\n                self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)\n                self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)\n\n            mx.eval(layer)\n            if on_layer_loaded is not None:\n                on_layer_loaded(i, total)\n        return model\n\n\nclass NemotronHShardingStrategy(TensorParallelShardingStrategy):\n    def shard_model(\n        self,\n        model: nn.Module,\n        timeout_seconds: float,\n        on_timeout: TimeoutCallback | None,\n        on_layer_loaded: LayerLoadedCallback | None,\n    ) -> nn.Module:\n        model = cast(NemotronHModel, model)\n        rank = self.group.rank()\n        total = len(model.layers)\n        for i, layer in enumerate(model.layers):\n            eval_with_timeout(layer.parameters(), timeout_seconds / total, on_timeout)\n\n            mixer = layer.mixer\n\n            if isinstance(mixer, NemotronHAttention):\n                mixer.q_proj = self.all_to_sharded_linear(mixer.q_proj)\n                mixer.k_proj = self.all_to_sharded_linear(mixer.k_proj)\n                mixer.v_proj = self.all_to_sharded_linear(mixer.v_proj)\n                mixer.o_proj = self.sharded_to_all_linear(mixer.o_proj)\n                mixer.num_heads //= self.N\n                mixer.num_key_value_heads //= self.N\n\n            elif isinstance(mixer, NemotronHMamba2Mixer):\n                self._shard_mamba2_mixer(mixer, rank)\n\n            elif isinstance(mixer, NemotronHMoE):\n                # Shard routed experts (SwitchMLP uses fc1/fc2)\n                self.all_to_sharded_linear_in_place(mixer.switch_mlp.fc1)\n                self.sharded_to_all_linear_in_place(mixer.switch_mlp.fc2)\n                # Shard shared expert in-place (no all-reduce — ShardedMoE handles that)\n                if hasattr(mixer, \"shared_experts\"):\n                    self.all_to_sharded_linear_in_place(mixer.shared_experts.up_proj)\n                    self.sharded_to_all_linear_in_place(mixer.shared_experts.down_proj)\n                mixer = ShardedMoE(mixer)  # pyright: ignore[reportArgumentType]\n                mixer.sharding_group = self.group\n                layer.mixer = mixer  # pyright: ignore[reportAttributeAccessIssue]\n\n            mx.eval(layer)\n            if on_layer_loaded is not None:\n                on_layer_loaded(i, total)\n        return model\n\n    def _shard_mamba2_mixer(self, mixer: NemotronHMamba2Mixer, rank: int) -> None:\n        \"\"\"Shard the Mamba2 mixer along the head dimension.\"\"\"\n        world_size = self.N\n        num_heads = mixer.num_heads\n        head_dim = mixer.head_dim\n        n_groups = mixer.n_groups\n        ssm_state_size = mixer.ssm_state_size\n        intermediate_size = mixer.intermediate_size  # = num_heads * head_dim\n\n        # Per-rank sizes\n        heads_per_rank = num_heads // world_size\n        
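# e.g. num_heads=64 with world_size=4 gives 16 heads per rank; exact divisibility is assumed\n        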
groups_per_rank = n_groups // world_size\n        is_per_rank = heads_per_rank * head_dim\n        bc_per_rank = groups_per_rank * ssm_state_size\n\n        # === in_proj: output layout is [gate:IS | conv_ssm:IS | B:NG*SS | C:NG*SS | dt:NH] ===\n        gate_start = 0\n        conv_ssm_start = intermediate_size\n        b_start = 2 * intermediate_size\n        c_start = b_start + n_groups * ssm_state_size\n        dt_start = c_start + n_groups * ssm_state_size\n\n        # Build index tensor for this rank's slice of each section\n        gate_idx = mx.arange(\n            gate_start + rank * is_per_rank, gate_start + (rank + 1) * is_per_rank\n        )\n        conv_ssm_idx = mx.arange(\n            conv_ssm_start + rank * is_per_rank,\n            conv_ssm_start + (rank + 1) * is_per_rank,\n        )\n        b_idx = mx.arange(\n            b_start + rank * bc_per_rank, b_start + (rank + 1) * bc_per_rank\n        )\n        c_idx = mx.arange(\n            c_start + rank * bc_per_rank, c_start + (rank + 1) * bc_per_rank\n        )\n        dt_idx = mx.arange(\n            dt_start + rank * heads_per_rank, dt_start + (rank + 1) * heads_per_rank\n        )\n\n        indices = mx.concatenate([gate_idx, conv_ssm_idx, b_idx, c_idx, dt_idx])\n        mixer.in_proj.weight = mixer.in_proj.weight[indices]\n\n        # === out_proj: input is intermediate_size (sharded) → hidden_size (reduce) ===\n        mixer.out_proj = self.sharded_to_all_linear(mixer.out_proj)\n\n        # === conv1d: depthwise conv on conv_dim channels ===\n        # conv_dim layout: [ssm_hidden:IS | B:NG*SS | C:NG*SS]\n        conv_ssm_idx_local = mx.arange(rank * is_per_rank, (rank + 1) * is_per_rank)\n        conv_b_idx = mx.arange(\n            intermediate_size + rank * bc_per_rank,\n            intermediate_size + (rank + 1) * bc_per_rank,\n        )\n        conv_c_idx = mx.arange(\n            intermediate_size + n_groups * ssm_state_size + rank * bc_per_rank,\n            intermediate_size + n_groups * ssm_state_size + (rank + 1) * bc_per_rank,\n        )\n        conv_indices = mx.concatenate([conv_ssm_idx_local, conv_b_idx, conv_c_idx])\n        mixer.conv1d.weight = mixer.conv1d.weight[conv_indices]\n        new_conv_dim = is_per_rank + 2 * bc_per_rank\n        mixer.conv1d.groups = new_conv_dim\n        if mixer.conv1d.bias is not None:\n            mixer.conv1d.bias = mixer.conv1d.bias[conv_indices]\n\n        # === Per-head parameters ===\n        h_start = rank * heads_per_rank\n        h_end = h_start + heads_per_rank\n        mixer.dt_bias = mixer.dt_bias[h_start:h_end]\n        mixer.A_log = mixer.A_log[h_start:h_end]\n        mixer.D = mixer.D[h_start:h_end]\n\n        # === Norm: weight is intermediate_size ===\n        mixer.norm.weight = mixer.norm.weight[\n            rank * is_per_rank : (rank + 1) * is_per_rank\n        ]\n\n        # === Update dimensions ===\n        mixer.num_heads = heads_per_rank\n        mixer.n_groups = groups_per_rank\n        mixer.intermediate_size = is_per_rank\n        mixer.conv_dim = new_conv_dim\n        mixer.heads_per_group = heads_per_rank // groups_per_rank\n"
  },
  {
    "path": "src/exo/worker/engines/mlx/cache.py",
    "content": "import os\nfrom copy import deepcopy\n\nimport mlx.core as mx\nimport psutil\nfrom mlx_lm.models.cache import (\n    ArraysCache,\n    CacheList,\n    KVCache,\n    QuantizedKVCache,\n    RotatingKVCache,\n)\nfrom mlx_lm.tokenizer_utils import TokenizerWrapper\n\nfrom exo.shared.types.memory import Memory\nfrom exo.shared.types.mlx import KVCacheType, Model\nfrom exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE, KV_CACHE_BITS\nfrom exo.worker.runner.bootstrap import logger\n\n\n# Fraction of device memory above which LRU eviction kicks in.\n# Smaller machines need more aggressive eviction.\ndef _default_memory_threshold() -> float:\n    total_gb = Memory.from_bytes(psutil.virtual_memory().total).in_gb\n    if total_gb >= 128:\n        return 0.85\n    if total_gb >= 64:\n        return 0.80\n    if total_gb >= 32:\n        return 0.75\n    return 0.70\n\n\n_MEMORY_THRESHOLD = float(\n    os.environ.get(\"EXO_MEMORY_THRESHOLD\", _default_memory_threshold())\n)\n\n\nclass CacheSnapshot:\n    \"\"\"Snapshot of states at a known token position.\"\"\"\n\n    def __init__(\n        self, states: list[RotatingKVCache | ArraysCache | None], token_count: int\n    ):\n        self.states = states\n        self.token_count = token_count\n\n\ndef snapshot_ssm_states(cache: KVCacheType) -> CacheSnapshot:\n    states: list[ArraysCache | RotatingKVCache | None] = []\n    for c in cache:\n        if isinstance(c, (ArraysCache, RotatingKVCache)):\n            states.append(deepcopy(c))\n        else:\n            states.append(None)\n    token_count = cache_length(cache)\n    return CacheSnapshot(states=states, token_count=token_count)\n\n\ndef _find_nearest_snapshot(\n    snapshots: list[CacheSnapshot],\n    target_token_count: int,\n) -> CacheSnapshot | None:\n    best: CacheSnapshot | None = None\n    for snap in snapshots:\n        if snap.token_count <= target_token_count and (\n            best is None or snap.token_count > best.token_count\n        ):\n            best = snap\n    return best\n\n\ndef has_non_kv_caches(cache: KVCacheType) -> bool:\n    \"\"\"Check if a cache contains any ArraysCache (SSM) entries.\"\"\"\n    return any(isinstance(c, (ArraysCache, RotatingKVCache)) for c in cache)\n\n\nclass KVPrefixCache:\n    def __init__(self, group: mx.distributed.Group | None):\n        self.prompts: list[mx.array] = []  # mx array of tokens (ints)\n        self.caches: list[KVCacheType] = []\n        self._snapshots: list[list[CacheSnapshot] | None] = []\n        self._last_used: list[int] = []  # monotonic counter of last access per entry\n        self._access_counter: int = 0\n        self._group = group\n\n    def clear(self):\n        \"\"\"Clear all cached prompts and caches.\"\"\"\n        self.prompts.clear()\n        self.caches.clear()\n        self._snapshots.clear()\n        self._last_used.clear()\n\n    def add_kv_cache(\n        self,\n        prompt_tokens: mx.array,\n        cache: KVCacheType,\n        ssm_snapshots: list[CacheSnapshot] | None = None,\n    ):\n        \"\"\"Add a new cache entry. 
Evicts LRU entries if memory is high.\"\"\"\n        self._evict_if_needed()\n        self.prompts.append(prompt_tokens)\n        self.caches.append(deepcopy(cache))\n        self._snapshots.append(ssm_snapshots)\n        self._access_counter += 1\n        self._last_used.append(self._access_counter)\n        logger.info(f\"KV cache added: {len(prompt_tokens)} tokens\")\n\n    def update_kv_cache(\n        self,\n        index: int,\n        prompt_tokens: mx.array,\n        cache: KVCacheType,\n        snapshots: list[CacheSnapshot] | None,\n        restore_pos: int,\n    ):\n        \"\"\"Update an existing cache entry in-place.\"\"\"\n        old_snapshots = self._snapshots[index]\n        merged: list[CacheSnapshot] = []\n        if old_snapshots:\n            merged = [s for s in old_snapshots if s.token_count <= restore_pos]\n        if snapshots:\n            merged.extend(snapshots)\n\n        self.prompts[index] = prompt_tokens\n        self.caches[index] = deepcopy(cache)\n        self._snapshots[index] = merged or None\n        self._access_counter += 1\n        self._last_used[index] = self._access_counter\n        logger.info(f\"KV cache updated (index {index}): {len(prompt_tokens)} tokens\")\n\n    def _get_snapshot(\n        self, entry_index: int, target_token_count: int\n    ) -> tuple[int, CacheSnapshot | None]:\n        if not has_non_kv_caches(self.caches[entry_index]):\n            return target_token_count, None\n\n        snapshots = self._snapshots[entry_index]\n        if not snapshots:\n            return 0, None\n\n        snap = _find_nearest_snapshot(snapshots, target_token_count)\n        if snap is not None:\n            return snap.token_count, snap\n\n        return 0, None\n\n    def get_kv_cache(\n        self,\n        model: Model,\n        prompt_tokens: mx.array,\n    ) -> tuple[KVCacheType, mx.array, int | None]:\n        \"\"\"Get KV cache for prompt, returning remaining tokens to prefill.\n\n        Returns:\n            Tuple of (cache, remaining_tokens, matched_index) where:\n            - cache: KV cache to use for generation\n            - remaining_tokens: tokens that still need prefilling\n            - matched_index: index of the matched entry (None if no match)\n\n        For models with SSM layers (which are ArraysCache in mlx), the cache is trimmed to the\n        nearest SSM snapshot position at or before the match point for correctness.\n        Same for rotating KV Cache.\n        \"\"\"\n        max_length = len(prompt_tokens)\n\n        best_index: int | None = None\n        best_length = 0\n        is_exact = False\n\n        # Find best cache match\n        for i, cached_prompt in enumerate(self.prompts):\n            length = get_prefix_length(prompt_tokens, cached_prompt)\n            if length >= max_length - 1:\n                best_index, best_length = i, length\n                is_exact = True\n                break\n            if length > best_length:\n                best_index, best_length = i, length\n\n        if best_index is None:\n            return make_kv_cache(model), prompt_tokens, None\n\n        # For exact match: trim to max_length-1 so remaining has the last token\n        # For partial match: trim to best_length, remaining has suffix to prefill\n        # This ensures stream_generate always has at least one token to start with\n        has_ssm = has_non_kv_caches(self.caches[best_index])\n        target = (max_length - 1) if is_exact and not has_ssm else best_length\n        restore_pos, restore_snap = 
self._get_snapshot(best_index, target)\n\n        # No usable snapshot — need fresh cache\n        if restore_snap is None and has_ssm:\n            return make_kv_cache(model), prompt_tokens, None\n\n        prompt_cache = deepcopy(self.caches[best_index])\n        cached_length = cache_length(self.caches[best_index])\n        tokens_to_trim = cached_length - restore_pos\n        if tokens_to_trim > 0:\n            trim_cache(prompt_cache, tokens_to_trim, restore_snap)\n            # Reset cache offset to match trimmed length\n            for c in prompt_cache:\n                if hasattr(c, \"offset\"):\n                    c.offset = restore_pos\n\n        self._access_counter += 1\n        self._last_used[best_index] = self._access_counter\n        remaining = prompt_tokens[restore_pos:]\n\n        return prompt_cache, remaining, best_index\n\n    def _evict_if_needed(self):\n        \"\"\"Evict least recently used entries while memory usage is high.\"\"\"\n        if len(self.caches) == 0:\n            return\n\n        # Evict LRU entries until below threshold\n        while (\n            len(self.caches) > 0\n            and self.get_memory_used_percentage() > _MEMORY_THRESHOLD\n        ):\n            lru_index = self._last_used.index(min(self._last_used))\n            evicted_tokens = len(self.prompts[lru_index])\n            self.prompts.pop(lru_index)\n            self.caches.pop(lru_index)\n            self._snapshots.pop(lru_index)\n            self._last_used.pop(lru_index)\n            logger.info(\n                f\"KV cache evicted LRU entry ({evicted_tokens} tokens) due to memory usage\"\n            )\n\n    def get_memory_used_percentage(self) -> float:\n        local_pressure: float = get_memory_used_percentage()\n\n        if self._group is None:\n            return local_pressure\n\n        all_pressure = mx.distributed.all_gather(\n            mx.array([local_pressure], dtype=mx.float32),\n            group=self._group,\n        )\n        # .item() evals.\n        max_pressure = float(mx.max(all_pressure).item())\n        return max_pressure\n\n\ndef trim_cache(\n    cache: KVCacheType,\n    num_tokens: int,\n    snapshot: CacheSnapshot | None = None,\n) -> None:\n    for i, c in enumerate(cache):\n        if isinstance(c, (ArraysCache, RotatingKVCache)):\n            if snapshot is not None and snapshot.states[i] is not None:\n                cache[i] = deepcopy(snapshot.states[i])  # type: ignore\n            else:\n                c.state = [None] * len(c.state)\n        else:\n            c.trim(num_tokens)\n\n\ndef encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:\n    \"\"\"Encode a prompt string to token array.\n\n    For chat-templated prompts (which have their own structure markers like\n    <|im_user|>, <|im_middle|>, etc.), we should NOT add BOS/EOS tokens as\n    that would corrupt the prompt structure.\n    \"\"\"\n    # Chat templates define their own structure - don't add BOS/EOS\n    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)\n    return mx.array(prompt_tokens)\n\n\ndef _entry_length(\n    c: KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache | CacheList,\n) -> int:\n    # Use .offset attribute which KVCache types have (len() not implemented in older QuantizedKVCache).\n    if hasattr(c, \"offset\"):\n        return c.offset\n    # For CacheList\n    if hasattr(c, \"size\"):\n        return int(c.size())  # type: ignore\n    return 0\n\n\ndef cache_length(cache: KVCacheType) -> int:\n    \"\"\"Get 
the number of tokens in a KV cache.\"\"\"\n    return max(_entry_length(c) for c in cache)\n\n\ndef get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:\n    \"\"\"Find the length of the common prefix between two token arrays.\"\"\"\n    n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]))\n    if n == 0:\n        return 0\n\n    equal = mx.equal(prompt[:n], cached_prompt[:n]).astype(mx.int32)\n    prefix_mask = mx.cumprod(equal)  # stays 1 until first mismatch, then 0 forever\n    return int(mx.sum(prefix_mask).item())\n\n\ndef get_available_memory() -> Memory:\n    mem: int = psutil.virtual_memory().available\n    return Memory.from_bytes(mem)\n\n\ndef get_memory_used_percentage() -> float:\n    mem = psutil.virtual_memory()\n    # percent is 0-100\n    return float(mem.percent / 100)\n\n\ndef make_kv_cache(\n    model: Model, max_kv_size: int | None = None, keep: int = 0\n) -> KVCacheType:\n    assert hasattr(model, \"layers\")\n\n    if hasattr(model, \"make_cache\"):\n        logger.info(\"Using MLX LM's make cache\")\n        return model.make_cache()  # type: ignore\n\n    if max_kv_size is None:\n        if KV_CACHE_BITS is None:\n            logger.info(\"Using default KV cache\")\n            return [KVCache() for _ in model.layers]\n        else:\n            logger.info(\"Using quantized KV cache\")\n            return [\n                QuantizedKVCache(group_size=CACHE_GROUP_SIZE, bits=KV_CACHE_BITS)\n                for _ in model.layers\n            ]\n    else:\n        logger.info(f\"Using rotating KV cache with {max_kv_size=} with {keep=}\")\n        return [RotatingKVCache(max_size=max_kv_size, keep=keep) for _ in model.layers]\n"
  },
  {
    "path": "src/exo/worker/engines/mlx/constants.py",
    "content": "# TODO: Do we want so many constants?\n#  I think we want a lot of these as parameters?\n\nKV_GROUP_SIZE: int | None = 32\nKV_BITS: int | None = None\nATTENTION_KV_BITS: int | None = 4\nMAX_TOKENS: int = 32168\nMAX_KV_SIZE: int | None = 3200\nKEEP_KV_SIZE: int | None = 1600\nQUANTIZE_MODEL_MODE: str | None = \"affine\"\nCACHE_GROUP_SIZE: int = 64\nKV_CACHE_BITS: int | None = None\n\nDEFAULT_TOP_LOGPROBS: int = 5\n\n# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True\nTRUST_REMOTE_CODE: bool = True\n"
  },
  {
    "path": "src/exo/worker/engines/mlx/dsml_encoding.py",
    "content": "import json\nimport re\nfrom typing import Any\n\nfrom mlx_lm.chat_templates import deepseek_v32\n\nfrom exo.api.types import ToolCallItem\n\nBOS_TOKEN: str = deepseek_v32.bos_token\nEOS_TOKEN: str = deepseek_v32.eos_token\nDSML_TOKEN: str = deepseek_v32.dsml_token\nTHINKING_START: str = deepseek_v32.thinking_start_token\nTHINKING_END: str = deepseek_v32.thinking_end_token\nUSER_TOKEN = \"<\\uff5cUser\\uff5c>\"\nASSISTANT_TOKEN = \"<\\uff5cAssistant\\uff5c>\"\nTOOL_CALLS_START = f\"<{DSML_TOKEN}function_calls>\"\nTOOL_CALLS_END = f\"</{DSML_TOKEN}function_calls>\"\nencode_messages = deepseek_v32.encode_messages\n\n_INVOKE_PATTERN = re.compile(\n    rf\"<{re.escape(DSML_TOKEN)}invoke\\s+name=\\\"([^\\\"]+)\\\">\"\n    rf\"(.*?)\"\n    rf\"</{re.escape(DSML_TOKEN)}invoke>\",\n    re.DOTALL,\n)\n\n_PARAM_PATTERN = re.compile(\n    rf\"<{re.escape(DSML_TOKEN)}parameter\\s+name=\\\"([^\\\"]+)\\\"\\s+string=\\\"(true|false)\\\">\"\n    rf\"(.*?)\"\n    rf\"</{re.escape(DSML_TOKEN)}parameter>\",\n    re.DOTALL,\n)\n\n\ndef parse_dsml_output(text: str) -> list[ToolCallItem] | None:\n    \"\"\"Parse DSML function_calls block from model output text.\n\n    Args:\n        text: The text containing the DSML function_calls block\n              (including the start/end markers).\n\n    Returns:\n        List of ToolCallItem, or None if parsing fails.\n    \"\"\"\n    tool_calls: list[ToolCallItem] = []\n\n    for invoke_match in _INVOKE_PATTERN.finditer(text):\n        func_name = invoke_match.group(1)\n        invoke_body = invoke_match.group(2)\n\n        args: dict[str, Any] = {}\n        for param_match in _PARAM_PATTERN.finditer(invoke_body):\n            param_name = param_match.group(1)\n            is_string = param_match.group(2) == \"true\"\n            param_value = param_match.group(3)\n\n            if is_string:\n                args[param_name] = param_value\n            else:\n                try:\n                    args[param_name] = json.loads(param_value)\n                except (json.JSONDecodeError, ValueError):\n                    args[param_name] = param_value\n\n        tool_calls.append(\n            ToolCallItem(\n                name=func_name,\n                arguments=json.dumps(args),\n            )\n        )\n\n    return tool_calls if tool_calls else None\n"
  },
  {
    "path": "src/exo/worker/engines/mlx/generator/__init__.py",
    "content": ""
  },
  {
    "path": "src/exo/worker/engines/mlx/generator/batch_generate.py",
    "content": "import time\nfrom dataclasses import dataclass, field\nfrom typing import Callable, cast\n\nimport mlx.core as mx\nfrom mlx_lm.generate import (\n    BatchGenerator as MlxBatchGenerator,\n)\nfrom mlx_lm.models.cache import RotatingKVCache\nfrom mlx_lm.sample_utils import make_logits_processors, make_sampler\nfrom mlx_lm.tokenizer_utils import StreamingDetokenizer, TokenizerWrapper\n\nfrom exo.api.types import (\n    CompletionTokensDetails,\n    FinishReason,\n    GenerationStats,\n    PromptTokensDetails,\n    TopLogprobItem,\n    Usage,\n)\nfrom exo.shared.types.memory import Memory\nfrom exo.shared.types.mlx import KVCacheType, Model\nfrom exo.shared.types.text_generation import TextGenerationTaskParams\nfrom exo.shared.types.worker.runner_response import GenerationResponse\nfrom exo.worker.engines.mlx.cache import (\n    CacheSnapshot,\n    KVPrefixCache,\n    encode_prompt,\n    make_kv_cache,\n)\nfrom exo.worker.engines.mlx.constants import DEFAULT_TOP_LOGPROBS, MAX_TOKENS\nfrom exo.worker.engines.mlx.generator.generate import (\n    ban_token_ids,\n    eos_ids_from_tokenizer,\n    extract_top_logprobs,\n    prefill,\n)\nfrom exo.worker.engines.mlx.utils_mlx import fix_unmatched_think_end_tokens\nfrom exo.worker.runner.bootstrap import logger\n\n_MIN_PREFIX_HIT_RATIO_TO_UPDATE = 0.5\n\n\ndef _stop_sequences(task_params: TextGenerationTaskParams) -> list[str]:\n    if task_params.stop is None:\n        return []\n    if isinstance(task_params.stop, str):\n        return [task_params.stop]\n    return task_params.stop\n\n\n@dataclass\nclass _EngineTask:\n    uid: int\n    task_params: TextGenerationTaskParams\n    all_prompt_tokens: mx.array\n    prefix_hit_length: int\n    matched_index: int | None\n    cache_snapshots: list[CacheSnapshot] | None\n    detokenizer: StreamingDetokenizer\n    on_generation_token: Callable[[], None] | None = None\n    generated_text_parts: list[str] = field(default_factory=list)\n    potential_stop_sequence_text: str = \"\"\n    completion_tokens: int = 0\n    generation_start_time: float = 0.0\n    in_thinking: bool = False\n    reasoning_tokens: int = 0\n    prefill_tps: float = 0.0\n\n\n@dataclass(eq=False)\nclass ExoBatchGenerator:\n    model: Model\n    tokenizer: TokenizerWrapper\n    group: mx.distributed.Group | None\n    kv_prefix_cache: KVPrefixCache | None\n\n    _exo_gen: MlxBatchGenerator = field(init=False)\n    _active_tasks: dict[int, _EngineTask] = field(default_factory=dict, init=False)\n\n    def __post_init__(self) -> None:\n        self._exo_gen = MlxBatchGenerator(\n            model=self.model,\n            stop_tokens=set(eos_ids_from_tokenizer(self.tokenizer)),\n            prefill_step_size=4096,\n        )\n\n    @property\n    def has_work(self) -> bool:\n        return (\n            bool(self._active_tasks)\n            or bool(self._exo_gen.unprocessed_prompts)\n            or self._exo_gen.active_batch is not None\n        )\n\n    def submit(\n        self,\n        task_params: TextGenerationTaskParams,\n        prompt: str,\n        on_prefill_progress: Callable[[int, int], None] | None = None,\n        distributed_prompt_progress_callback: Callable[[], None] | None = None,\n        on_generation_token: Callable[[], None] | None = None,\n    ) -> int:\n        all_prompt_tokens = encode_prompt(self.tokenizer, prompt)\n        all_prompt_tokens = fix_unmatched_think_end_tokens(\n            all_prompt_tokens, self.tokenizer\n        )\n\n        is_bench = task_params.bench\n\n        prefix_hit_length = 
0\n        matched_index: int | None = None\n        prompt_tokens = all_prompt_tokens\n\n        if self.kv_prefix_cache is not None and not is_bench:\n            cache, remaining_tokens, matched_index = self.kv_prefix_cache.get_kv_cache(\n                self.model, all_prompt_tokens\n            )\n            prefix_hit_length = len(all_prompt_tokens) - len(remaining_tokens)\n            if prefix_hit_length > 0:\n                logger.info(\n                    f\"KV cache hit: {prefix_hit_length}/{len(all_prompt_tokens)} tokens \"\n                    f\"cached ({100 * prefix_hit_length / len(all_prompt_tokens):.1f}%)\"\n                )\n                prompt_tokens = remaining_tokens\n            else:\n                cache = make_kv_cache(self.model)\n        else:\n            cache = make_kv_cache(self.model)\n\n        seed = task_params.seed if task_params.seed is not None else 42\n        mx.random.seed(seed)\n\n        sampler = make_sampler(\n            temp=task_params.temperature\n            if task_params.temperature is not None\n            else 0.7,\n            top_p=task_params.top_p if task_params.top_p is not None else 1.0,\n            min_p=task_params.min_p if task_params.min_p is not None else 0.05,\n            top_k=task_params.top_k if task_params.top_k is not None else 0,\n        )\n\n        _prefill_tps, _prefill_tokens, cache_snapshots = prefill(\n            self.model,\n            self.tokenizer,\n            sampler,\n            prompt_tokens[:-1],\n            cache,\n            self.group,\n            on_prefill_progress,\n            distributed_prompt_progress_callback,\n        )\n\n        # We need to clamp rotating kv caches to max size so that mlx lm's _merge_caches behaves correctly\n        for c in cache:\n            if (\n                isinstance(c, RotatingKVCache)\n                and c.keys is not None\n                and c.values is not None\n                and c.keys.shape[2] > c.max_size\n            ):\n                trim_size = c.keys.shape[2] - c.max_size\n                c.keys = c._trim(trim_size, c.keys)\n                c.values = c._trim(trim_size, c.values)\n                c._idx = c.max_size\n\n        if not is_bench:\n            self._save_prefix_cache(\n                all_prompt_tokens,\n                list(cache),\n                cache_snapshots,\n                prefix_hit_length,\n                matched_index,\n            )\n\n        last_tokens = prompt_tokens[-2:]\n\n        logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = (\n            make_logits_processors(\n                repetition_penalty=task_params.repetition_penalty,\n                repetition_context_size=task_params.repetition_context_size,\n            )\n        )\n        if is_bench:\n            # Ban EOS tokens so bench runs are terminated by length only\n            eos_ids = eos_ids_from_tokenizer(self.tokenizer)\n            logits_processors = [ban_token_ids(eos_ids)] + logits_processors\n\n        max_tokens = task_params.max_output_tokens or MAX_TOKENS\n\n        uids = self._exo_gen.insert(\n            prompts=[last_tokens.tolist()],\n            max_tokens=[max_tokens],\n            caches=[list(cache)],\n            samplers=[sampler],\n            logits_processors=[logits_processors],\n        )\n\n        assert len(uids) == 1\n\n        uid = uids[0]\n\n        self._active_tasks[uid] = _EngineTask(\n            uid=uid,\n            task_params=task_params,\n            all_prompt_tokens=all_prompt_tokens,\n            
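# prefix_hit_length and matched_index tie this task back to the prefix-cache entry it resumed from\n            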
prefix_hit_length=prefix_hit_length,\n            matched_index=matched_index,\n            cache_snapshots=cache_snapshots or None,\n            detokenizer=self.tokenizer.detokenizer,\n            on_generation_token=on_generation_token,\n            generation_start_time=time.perf_counter(),\n            prefill_tps=_prefill_tps,\n        )\n\n        return uid\n\n    def step(self) -> list[tuple[int, GenerationResponse]]:\n        if not self.has_work:\n            return []\n\n        responses = self._exo_gen.next()\n\n        results: list[tuple[int, GenerationResponse]] = []\n\n        for response in responses:\n            if response.uid not in self._active_tasks:\n                logger.warning(\n                    f\"response uid {response.uid} was not found - should be active\"\n                )\n                continue\n\n            state = self._active_tasks[response.uid]\n            if state.on_generation_token is not None:\n                state.on_generation_token()\n            if response.finish_reason != \"stop\":\n                state.detokenizer.add_token(response.token)\n            if response.finish_reason is not None:\n                state.detokenizer.finalize()\n            text = state.detokenizer.last_segment\n            state.completion_tokens += 1\n            state.generated_text_parts.append(text)\n            state.potential_stop_sequence_text += text\n\n            think_start = self.tokenizer.think_start\n            think_end = self.tokenizer.think_end\n            if think_start is not None and text == think_start:\n                state.in_thinking = True\n            elif think_end is not None and text == think_end:\n                state.in_thinking = False\n            if state.in_thinking:\n                state.reasoning_tokens += 1\n\n            finish_reason: FinishReason | None = cast(\n                FinishReason | None, response.finish_reason\n            )\n            task_params = state.task_params\n            stop_sequences = _stop_sequences(task_params)\n            max_stop_len = max((len(s) for s in stop_sequences), default=0)\n\n            if stop_sequences:\n                for stop_seq in stop_sequences:\n                    if stop_seq in state.potential_stop_sequence_text:\n                        stop_index = state.potential_stop_sequence_text.find(stop_seq)\n                        text_before_stop = state.potential_stop_sequence_text[\n                            :stop_index\n                        ]\n                        chunk_start = len(state.potential_stop_sequence_text) - len(\n                            text\n                        )\n                        text = text_before_stop[chunk_start:]\n                        finish_reason = \"stop\"\n                        break\n\n            is_done = finish_reason is not None\n\n            logprob: float | None = None\n            top_logprobs: list[TopLogprobItem] | None = None\n            if task_params.logprobs:\n                logprob, top_logprobs = extract_top_logprobs(\n                    logprobs=response.logprobs,\n                    tokenizer=self.tokenizer,\n                    top_logprobs=task_params.top_logprobs or DEFAULT_TOP_LOGPROBS,\n                    selected_token=response.token,\n                )\n\n            stats: GenerationStats | None = None\n            usage: Usage | None = None\n            if is_done:\n                try:\n                    mlx_stats = self._exo_gen.stats()\n                    generation_tps 
= mlx_stats.generation_tps\n                except ZeroDivisionError:\n                    generation_elapsed = (\n                        time.perf_counter() - state.generation_start_time\n                    )\n                    generation_tps = (\n                        state.completion_tokens / generation_elapsed\n                        if generation_elapsed > 0\n                        else 0.0\n                    )\n\n                stats = GenerationStats(\n                    prompt_tps=state.prefill_tps,\n                    generation_tps=generation_tps,\n                    prompt_tokens=len(state.all_prompt_tokens),\n                    generation_tokens=state.completion_tokens,\n                    peak_memory_usage=Memory.from_gb(mx.get_peak_memory() / 1e9),\n                )\n                total_prompt_tokens = len(state.all_prompt_tokens)\n                usage = Usage(\n                    prompt_tokens=total_prompt_tokens,\n                    completion_tokens=state.completion_tokens,\n                    total_tokens=total_prompt_tokens + state.completion_tokens,\n                    prompt_tokens_details=PromptTokensDetails(\n                        cached_tokens=state.prefix_hit_length\n                    ),\n                    completion_tokens_details=CompletionTokensDetails(\n                        reasoning_tokens=state.reasoning_tokens\n                    ),\n                )\n\n            results.append(\n                (\n                    response.uid,\n                    GenerationResponse(\n                        text=text,\n                        token=response.token,\n                        logprob=logprob,\n                        top_logprobs=top_logprobs,\n                        finish_reason=finish_reason,\n                        stats=stats,\n                        usage=usage,\n                    ),\n                )\n            )\n\n            if is_done:\n                del self._active_tasks[response.uid]\n            elif (\n                max_stop_len > 0\n                and len(state.potential_stop_sequence_text) > max_stop_len\n            ):\n                state.potential_stop_sequence_text = state.potential_stop_sequence_text[\n                    -max_stop_len:\n                ]\n\n        return results\n\n    def cancel(self, uids: list[int]) -> None:\n        self._exo_gen.remove(uids)\n        for uid in uids:\n            self._active_tasks.pop(uid, None)\n\n    def close(self) -> None:\n        self._exo_gen.close()\n\n    def _save_prefix_cache(\n        self,\n        all_prompt_tokens: mx.array,\n        cache: KVCacheType,\n        cache_snapshots: list[CacheSnapshot] | None,\n        prefix_hit_length: int,\n        matched_index: int | None,\n    ) -> None:\n        if self.kv_prefix_cache is None:\n            return\n\n        try:\n            hit_ratio = (\n                prefix_hit_length / len(all_prompt_tokens)\n                if len(all_prompt_tokens) > 0\n                else 0.0\n            )\n            if (\n                matched_index is not None\n                and hit_ratio >= _MIN_PREFIX_HIT_RATIO_TO_UPDATE\n            ):\n                self.kv_prefix_cache.update_kv_cache(\n                    matched_index,\n                    all_prompt_tokens,\n                    cache,\n                    cache_snapshots,\n                    restore_pos=prefix_hit_length,\n                )\n            else:\n                self.kv_prefix_cache.add_kv_cache(\n                
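    # no usable match or low overlap: add a fresh entry rather than updating in place\n                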
    all_prompt_tokens, cache, cache_snapshots\n                )\n        except Exception:\n            logger.warning(\"Failed to save prefix cache\", exc_info=True)\n"
  },
  {
    "path": "src/exo/worker/engines/mlx/generator/generate.py",
    "content": "import functools\nimport math\nimport time\nfrom copy import deepcopy\nfrom typing import Callable, Generator, cast, get_args\n\nimport mlx.core as mx\nfrom mlx_lm.generate import (\n    maybe_quantize_kv_cache,\n    stream_generate,\n)\nfrom mlx_lm.models.cache import ArraysCache, RotatingKVCache\nfrom mlx_lm.sample_utils import make_logits_processors, make_sampler\nfrom mlx_lm.tokenizer_utils import TokenizerWrapper\n\nfrom exo.api.types import (\n    CompletionTokensDetails,\n    FinishReason,\n    GenerationStats,\n    PromptTokensDetails,\n    TopLogprobItem,\n    Usage,\n)\nfrom exo.shared.types.common import ModelId\nfrom exo.shared.types.memory import Memory\nfrom exo.shared.types.mlx import KVCacheType, Model\nfrom exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams\nfrom exo.shared.types.worker.runner_response import (\n    GenerationResponse,\n)\nfrom exo.worker.engines.mlx.auto_parallel import (\n    PipelineFirstLayer,\n    PipelineLastLayer,\n    clear_prefill_sends,\n    flush_prefill_sends,\n    set_pipeline_prefill,\n    set_pipeline_queue_sends,\n)\nfrom exo.worker.engines.mlx.cache import (\n    CacheSnapshot,\n    KVPrefixCache,\n    encode_prompt,\n    has_non_kv_caches,\n    make_kv_cache,\n    snapshot_ssm_states,\n)\nfrom exo.worker.engines.mlx.constants import (\n    DEFAULT_TOP_LOGPROBS,\n    KV_BITS,\n    KV_GROUP_SIZE,\n    MAX_TOKENS,\n)\nfrom exo.worker.engines.mlx.utils_mlx import (\n    apply_chat_template,\n    fix_unmatched_think_end_tokens,\n    mx_barrier,\n)\nfrom exo.worker.runner.bootstrap import logger\n\ngeneration_stream = mx.new_stream(mx.default_device())\n\n_MIN_PREFIX_HIT_RATIO_TO_UPDATE = 0.5\n\n\nclass PrefillCancelled(BaseException):\n    \"\"\"Raised when prefill is cancelled via the progress callback.\"\"\"\n\n\ndef _has_pipeline_communication_layer(model: Model):\n    for layer in model.layers:\n        if isinstance(layer, (PipelineFirstLayer, PipelineLastLayer)):\n            return True\n    return False\n\n\ndef pipeline_parallel_prefill(\n    model: Model,\n    prompt: mx.array,\n    prompt_cache: KVCacheType,\n    prefill_step_size: int,\n    kv_group_size: int | None,\n    kv_bits: int | None,\n    prompt_progress_callback: Callable[[int, int], None],\n    distributed_prompt_progress_callback: Callable[[], None] | None,\n    group: mx.distributed.Group,\n) -> None:\n    \"\"\"Prefill the KV cache for pipeline parallel with overlapping stages.\n\n    Each rank processes the full prompt through its real cache, offset by leading\n    and trailing dummy iterations.\n\n    Total iterations per rank = N_real_chunks + world_size - 1:\n      - rank r leading dummies  (skip_pipeline_io, throwaway cache)\n      - N_real_chunks real      (pipeline IO active, real cache)\n      - (world_size-1-r) trailing dummies (skip_pipeline_io, throwaway cache)\n\n    e.g.\n    Timeline (2 ranks, 3 chunks of 10240 tokens @ step=4096):\n        iter 0: R0 real[0:4096]     R1 dummy\n        iter 1: R0 real[4096:8192]  R1 real[0:4096]\n        iter 2: R0 real[8192:10240] R1 real[4096:8192]\n        iter 3: R0 dummy            R1 real[8192:10240]\n\n    This function is designed to match mlx_lm's stream_generate exactly in terms of\n    side effects (given the same prefill step size)\n    \"\"\"\n    prefill_step_size = prefill_step_size // min(4, group.size())\n\n    quantize_cache_fn: Callable[..., None] = functools.partial(\n        maybe_quantize_kv_cache,\n        quantized_kv_start=0,\n        
kv_group_size=kv_group_size,\n        kv_bits=kv_bits,\n    )\n\n    _prompt_cache: KVCacheType = prompt_cache\n    rank = group.rank()\n    world_size = group.size()\n\n    # Build list of real prompt chunk sizes\n    total = len(prompt)\n    real_chunk_sizes: list[int] = []\n    remaining = total - 1\n    while remaining:\n        n = min(prefill_step_size, remaining)\n        real_chunk_sizes.append(n)\n        remaining -= n\n    n_real = len(real_chunk_sizes)\n\n    # Each rank does: [rank leading dummies] [N real chunks] [world_size-1-rank trailing dummies]\n    n_leading = rank\n    n_trailing = world_size - 1 - rank\n    n_total = n_leading + n_real + n_trailing\n\n    t_start = time.perf_counter()\n    processed = 0\n    logger.info(\n        f\"[R{rank}] Pipeline prefill: {n_real} real + {n_leading} leading + {n_trailing} trailing = {n_total} iterations\"\n    )\n    clear_prefill_sends()\n\n    # Initial callback matching generate_step\n    prompt_progress_callback(0, total)\n\n    try:\n        with mx.stream(generation_stream):\n            for _ in range(n_leading):\n                if distributed_prompt_progress_callback is not None:\n                    distributed_prompt_progress_callback()\n\n            for i in range(n_real):\n                chunk_size = real_chunk_sizes[i]\n                model(\n                    prompt[processed : processed + chunk_size][None],\n                    cache=_prompt_cache,\n                )\n                quantize_cache_fn(_prompt_cache)\n                processed += chunk_size\n\n                if distributed_prompt_progress_callback is not None:\n                    distributed_prompt_progress_callback()\n\n                flush_prefill_sends()\n\n                prompt_progress_callback(processed, total)\n\n            for _ in range(n_trailing):\n                if distributed_prompt_progress_callback is not None:\n                    distributed_prompt_progress_callback()\n\n    finally:\n        clear_prefill_sends()\n\n    # Post-loop: process remaining 1 token + add +1 entry to match stream_generate.\n    for _ in range(2):\n        with mx.stream(generation_stream):\n            model(prompt[-1:][None], cache=_prompt_cache)\n            quantize_cache_fn(_prompt_cache)\n        flush_prefill_sends()\n\n    assert _prompt_cache is not None\n    mx.eval([c.state for c in _prompt_cache])  # type: ignore\n\n    # Final callback matching generate_step\n    prompt_progress_callback(total, total)\n\n    logger.info(\n        f\"[R{rank}] Prefill: {n_real} real + {n_leading}+{n_trailing} dummy iterations, \"\n        f\"Processed {processed} tokens in {(time.perf_counter() - t_start) * 1000:.1f}ms\"\n    )\n\n\ndef prefill(\n    model: Model,\n    tokenizer: TokenizerWrapper,\n    sampler: Callable[[mx.array], mx.array],\n    prompt_tokens: mx.array,\n    cache: KVCacheType,\n    group: mx.distributed.Group | None,\n    on_prefill_progress: Callable[[int, int], None] | None,\n    distributed_prompt_progress_callback: Callable[[], None] | None,\n) -> tuple[float, int, list[CacheSnapshot]]:\n    \"\"\"Prefill the KV cache with prompt tokens.\n\n    This runs the model over the prompt tokens to populate the cache,\n    then trims off the extra generated token.\n\n    Returns:\n        (tokens_per_sec, num_tokens, snapshots)\n    \"\"\"\n    num_tokens = len(prompt_tokens)\n    if num_tokens == 0:\n        return 0.0, 0, []\n\n    logger.debug(f\"Prefilling {num_tokens} tokens...\")\n    start_time = time.perf_counter()\n    has_ssm 
= has_non_kv_caches(cache)\n    snapshots: list[CacheSnapshot] = []\n\n    # TODO(evan): kill the callbacks/runner refactor\n    def progress_callback(processed: int, total: int) -> None:\n        elapsed = time.perf_counter() - start_time\n        tok_per_sec = processed / elapsed if elapsed > 0 else 0\n        logger.debug(\n            f\"Prefill progress: {processed}/{total} tokens ({tok_per_sec:.1f} tok/s)\"\n        )\n        if has_ssm:\n            snapshots.append(snapshot_ssm_states(cache))\n\n        if on_prefill_progress is not None:\n            on_prefill_progress(processed, total)\n\n    def combined_progress_callback(processed: int, total: int) -> None:\n        if distributed_prompt_progress_callback is not None:\n            distributed_prompt_progress_callback()\n        progress_callback(processed, total)\n\n    set_pipeline_prefill(model, is_prefill=True)\n\n    mx_barrier(group)\n    logger.info(\"Starting prefill\")\n\n    is_pipeline = _has_pipeline_communication_layer(model)\n\n    prefill_step_size = 4096\n\n    try:\n        if is_pipeline and num_tokens >= prefill_step_size:\n            set_pipeline_queue_sends(model, queue_sends=True)\n            assert group is not None, \"Pipeline prefill requires a distributed group\"\n            pipeline_parallel_prefill(\n                model=model,\n                prompt=prompt_tokens,\n                prompt_cache=cache,\n                prefill_step_size=prefill_step_size,\n                kv_group_size=KV_GROUP_SIZE,\n                kv_bits=KV_BITS,\n                prompt_progress_callback=progress_callback,\n                distributed_prompt_progress_callback=distributed_prompt_progress_callback,\n                group=group,\n            )\n        else:\n            # Use max_tokens=1 because max_tokens=0 does not work.\n            # We just throw away the generated token - we only care about filling the cache\n            for _ in stream_generate(\n                model=model,\n                tokenizer=tokenizer,\n                prompt=prompt_tokens,\n                max_tokens=1,\n                sampler=sampler,\n                prompt_cache=cache,\n                prefill_step_size=prefill_step_size,\n                kv_group_size=KV_GROUP_SIZE,\n                kv_bits=KV_BITS,\n                prompt_progress_callback=combined_progress_callback,\n            ):\n                break  # Stop after first iteration - cache is now filled\n    except PrefillCancelled:\n        set_pipeline_queue_sends(model, queue_sends=False)\n        set_pipeline_prefill(model, is_prefill=False)\n        raise\n\n    set_pipeline_queue_sends(model, queue_sends=False)\n    set_pipeline_prefill(model, is_prefill=False)\n\n    # stream_generate added 1 extra generated token to the cache, so we should trim it.\n    # Because of needing to roll back arrays cache, we will generate on 2 tokens so trim 1 more.\n    pre_gen = deepcopy(snapshots[-2]) if has_ssm else None\n    for i, c in enumerate(cache):\n        if has_ssm and isinstance(c, (ArraysCache, RotatingKVCache)):\n            assert pre_gen is not None\n            if pre_gen.states[i] is not None:\n                cache[i] = deepcopy(pre_gen.states[i])  # type: ignore\n        else:\n            assert not isinstance(c, (ArraysCache, RotatingKVCache))\n            c.trim(2)\n\n    elapsed = time.perf_counter() - start_time\n    tokens_per_sec = num_tokens / elapsed if elapsed > 0 else 0.0\n    logger.debug(\n        f\"Prefill complete: {num_tokens} tokens in 
{elapsed:.2f}s \"\n        f\"({tokens_per_sec:.1f} tok/s)\"\n    )\n    # Exclude the last snapshot\n    return tokens_per_sec, num_tokens, snapshots[:-1] if snapshots else []\n\n\ndef warmup_inference(\n    model: Model,\n    tokenizer: TokenizerWrapper,\n    group: mx.distributed.Group | None,\n    model_id: ModelId,\n) -> int:\n    logger.info(f\"warming up inference for instance: {model_id}\")\n    t = time.monotonic()\n\n    content = \"Prompt to warm up the inference engine. Repeat this.\"\n\n    warmup_prompt = apply_chat_template(\n        tokenizer=tokenizer,\n        task_params=TextGenerationTaskParams(\n            model=ModelId(\"\"),\n            input=[InputMessage(role=\"user\", content=content)],\n        ),\n    )\n\n    tokens_generated = 0\n\n    cache = make_kv_cache(\n        model=model,\n    )\n\n    # Use a default sampler for warmup\n    sampler = make_sampler(temp=0.0)\n\n    mx_barrier(group)\n\n    logger.info(\"Generating warmup tokens\")\n    for _r in stream_generate(\n        model=model,\n        tokenizer=tokenizer,\n        prompt=warmup_prompt,\n        max_tokens=50,\n        sampler=sampler,\n        prompt_cache=cache,\n        prefill_step_size=2048,\n        kv_group_size=KV_GROUP_SIZE,\n        kv_bits=KV_BITS,\n    ):\n        logger.info(f\"Generated warmup token: {_r.text}\")\n        tokens_generated += 1\n\n    logger.info(\"Generated ALL warmup tokens\")\n\n    mx_barrier(group)\n\n    logger.info(f\"warmed up by generating {tokens_generated} tokens\")\n    # Floor the elapsed time at 1ms to avoid division by zero, then cap the interval at 100.\n    check_for_cancel_every = min(\n        math.ceil(tokens_generated / max(time.monotonic() - t, 0.001)), 100\n    )\n    if group is not None:\n        check_for_cancel_every = int(\n            mx.max(\n                mx.distributed.all_gather(\n                    mx.array([check_for_cancel_every]),\n                    group=group,\n                )\n            ).item()\n        )\n\n    logger.info(\n        f\"runner checking for cancellation every {check_for_cancel_every} tokens\"\n    )\n\n    return check_for_cancel_every\n\n\ndef ban_token_ids(token_ids: list[int]) -> Callable[[mx.array, mx.array], mx.array]:\n    token_ids = [int(t) for t in token_ids]\n\n    def proc(_history: mx.array, logits: mx.array) -> mx.array:\n        for tid in token_ids:\n            logits[..., tid] = -1e9\n        return logits\n\n    return proc\n\n\ndef eos_ids_from_tokenizer(tokenizer: TokenizerWrapper) -> list[int]:\n    eos: list[int] | None = getattr(tokenizer, \"eos_token_ids\", None)\n    if eos is None:\n        return []\n    return eos\n\n\ndef extract_top_logprobs(\n    logprobs: mx.array,\n    tokenizer: TokenizerWrapper,\n    top_logprobs: int,\n    selected_token: int,\n) -> tuple[float, list[TopLogprobItem]]:\n    \"\"\"Extract the selected token's logprob and top alternative tokens.\n\n    Args:\n        logprobs: Full vocabulary logprobs array from MLX\n        tokenizer: Tokenizer for decoding token IDs to strings\n        top_logprobs: Number of top alternatives to return\n        selected_token: The token ID that was actually sampled\n\n    Returns:\n        Tuple of (selected_token_logprob, list of TopLogprobItem for top alternatives)\n    \"\"\"\n    # Get the logprob of the selected token\n    selected_logprob = float(logprobs[selected_token].item())\n\n    # Get top indices (most probable tokens)\n    # mx.argpartition gives indices that would partition the array\n    # We negate logprobs since argpartition finds smallest, and we want largest\n    
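# Illustrative sketch (hypothetical values): for logprobs [-0.1, -5.0, -0.7, -9.2] and\n    # top_logprobs=2, argpartition(-logprobs, 2)[:2] yields indices {0, 2} -- the two\n    # largest logprobs -- in unspecified order; the argsort below makes the order descending.\n    top_logprobs = 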
min(top_logprobs, logprobs.shape[0])  # Don't exceed vocab size\n    top_indices = mx.argpartition(-logprobs, top_logprobs)[:top_logprobs]\n\n    # Get the actual logprob values for these indices\n    top_values = logprobs[top_indices]\n\n    # Sort by logprob (descending) for consistent ordering\n    sort_order = mx.argsort(-top_values)\n    top_indices = top_indices[sort_order]\n    top_values = top_values[sort_order]\n\n    # Convert to list of TopLogprobItem\n    top_logprob_items: list[TopLogprobItem] = []\n    for i in range(top_logprobs):\n        token_id = int(top_indices[i].item())\n        token_logprob = float(top_values[i].item())\n        if math.isnan(token_logprob):\n            continue\n\n        # Decode token ID to string\n        token_str = tokenizer.decode([token_id])\n        # Get byte representation\n        token_bytes = list(token_str.encode(\"utf-8\"))\n        top_logprob_items.append(\n            TopLogprobItem(\n                token=token_str,\n                logprob=token_logprob,\n                bytes=token_bytes,\n            )\n        )\n\n    return selected_logprob, top_logprob_items\n\n\ndef mlx_generate(\n    model: Model,\n    tokenizer: TokenizerWrapper,\n    task: TextGenerationTaskParams,\n    prompt: str,\n    kv_prefix_cache: KVPrefixCache | None,\n    group: mx.distributed.Group | None,\n    on_prefill_progress: Callable[[int, int], None] | None = None,\n    distributed_prompt_progress_callback: Callable[[], None] | None = None,\n    on_generation_token: Callable[[], None] | None = None,\n) -> Generator[GenerationResponse]:\n    # Ensure generation stats report peak memory for this generation only\n    mx.reset_peak_memory()\n    # TODO: Randomise task seed and set in taskparams, instead of hard coding as 42.\n    seed = task.seed or 42\n    mx.random.seed(seed)\n\n    # Encode prompt once at the top and fix unmatched think tags\n    all_prompt_tokens = encode_prompt(tokenizer, prompt)\n    all_prompt_tokens = fix_unmatched_think_end_tokens(all_prompt_tokens, tokenizer)\n\n    # Do not use the prefix cache when benchmarking.\n    is_bench = task.bench\n    if is_bench:\n        kv_prefix_cache = None\n\n    # Use prefix cache if available, otherwise create fresh cache\n    prefix_hit_length = 0\n    matched_index: int | None = None\n    if kv_prefix_cache is None:\n        caches = make_kv_cache(model=model)\n        prompt_tokens = all_prompt_tokens\n    else:\n        caches, prompt_tokens, matched_index = kv_prefix_cache.get_kv_cache(\n            model, all_prompt_tokens\n        )\n        prefix_hit_length = len(all_prompt_tokens) - len(prompt_tokens)\n        if prefix_hit_length > 0:\n            logger.info(\n                f\"KV cache hit: {prefix_hit_length}/{len(all_prompt_tokens)} tokens cached ({100 * prefix_hit_length / len(all_prompt_tokens):.1f}%)\"\n            )\n\n    logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = (\n        make_logits_processors(\n            repetition_penalty=task.repetition_penalty,\n            repetition_context_size=task.repetition_context_size,\n        )\n    )\n    if is_bench:\n        # Ban EOS tokens so benchmark runs always generate up to the length limit\n        eos_ids = eos_ids_from_tokenizer(tokenizer)\n        logits_processors = [ban_token_ids(eos_ids)] + logits_processors\n\n    sampler = make_sampler(\n        temp=task.temperature if task.temperature is not None else 0.7,\n        top_p=task.top_p if task.top_p is not None else 1.0,\n        min_p=task.min_p if task.min_p is 
not None else 0.05,\n        top_k=task.top_k if task.top_k is not None else 0,\n    )\n\n    # Normalize stop sequences to a list\n    stop_sequences: list[str] = (\n        ([task.stop] if isinstance(task.stop, str) else task.stop)\n        if task.stop is not None\n        else []\n    )\n    max_stop_len = max((len(s) for s in stop_sequences), default=0)\n\n    # Prefill cache with all tokens except the last one\n    prefill_tps, prefill_tokens, ssm_snapshots_list = prefill(\n        model,\n        tokenizer,\n        sampler,\n        prompt_tokens[:-1],\n        caches,\n        group,\n        on_prefill_progress,\n        distributed_prompt_progress_callback,\n    )\n    cache_snapshots: list[CacheSnapshot] | None = ssm_snapshots_list or None\n\n    # stream_generate resumes from the last two prompt tokens (prefill trimmed two)\n    last_two_tokens = prompt_tokens[-2:]\n\n    max_tokens = task.max_output_tokens or MAX_TOKENS\n    accumulated_text = \"\"\n    generated_text_parts: list[str] = []\n    generation_start_time = time.perf_counter()\n    usage: Usage | None = None\n    in_thinking = False\n    reasoning_tokens = 0\n    think_start = tokenizer.think_start\n    think_end = tokenizer.think_end\n\n    logger.info(\"Starting decode\")\n    mx_barrier(group)\n\n    for completion_tokens, out in enumerate(\n        stream_generate(\n            model=model,\n            tokenizer=tokenizer,\n            prompt=last_two_tokens,\n            max_tokens=max_tokens,\n            sampler=sampler,\n            logits_processors=logits_processors,\n            prompt_cache=caches,\n            prefill_step_size=1,\n            kv_group_size=KV_GROUP_SIZE,\n            kv_bits=KV_BITS,\n        ),\n        start=1,\n    ):\n        generated_text_parts.append(out.text)\n        accumulated_text += out.text\n\n        if think_start is not None and out.text == think_start:\n            in_thinking = True\n        elif think_end is not None and out.text == think_end:\n            in_thinking = False\n        if in_thinking:\n            reasoning_tokens += 1\n\n        # Check for stop sequences\n        text = out.text\n        finish_reason: FinishReason | None = cast(\n            FinishReason | None, out.finish_reason\n        )\n        stop_matched = False\n\n        
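# Worked example (hypothetical values): accumulated_text=\"Hello STOP\", out.text=\"o STOP\",\n        # stop_seq=\"STOP\" -> stop_index=6, text_before_stop=\"Hello \", chunk_start=4,\n        # so the emitted chunk is trimmed to text_before_stop[4:] == \"o \".\n        if stop_sequences:\n            for stop_seq in stop_sequences:\n                if stop_seq in accumulated_text:\n                    # Trim text to just before the stop sequence\n                    stop_index = accumulated_text.find(stop_seq)\n                    text_before_stop = accumulated_text[:stop_index]\n                    chunk_start = len(accumulated_text) - len(out.text)\n                    text = text_before_stop[chunk_start:]\n                    finish_reason = \"stop\"\n                    stop_matched = True\n                    break\n\n        is_done = finish_reason is not None\n\n        stats: GenerationStats | None = None\n        if is_done:\n            stats = GenerationStats(\n                prompt_tps=float(prefill_tps or out.prompt_tps),\n                generation_tps=float(out.generation_tps),\n                prompt_tokens=int(prefill_tokens + out.prompt_tokens),\n                generation_tokens=int(out.generation_tokens),\n                peak_memory_usage=Memory.from_gb(out.peak_memory),\n            )\n            if not stop_matched and out.finish_reason not in get_args(FinishReason):\n                logger.warning(\n                    f\"Model generated unexpected finish_reason: {out.finish_reason}\"\n                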
)\n\n            total_prompt_tokens = len(all_prompt_tokens)\n            usage = Usage(\n                prompt_tokens=total_prompt_tokens,\n                completion_tokens=completion_tokens,\n                total_tokens=total_prompt_tokens + completion_tokens,\n                prompt_tokens_details=PromptTokensDetails(\n                    cached_tokens=prefix_hit_length\n                ),\n                completion_tokens_details=CompletionTokensDetails(\n                    reasoning_tokens=reasoning_tokens\n                ),\n            )\n\n        # Extract logprobs from the full vocabulary logprobs array\n        logprob: float | None = None\n        top_logprobs: list[TopLogprobItem] | None = None\n        if task.logprobs:\n            logprob, top_logprobs = extract_top_logprobs(\n                logprobs=out.logprobs,\n                tokenizer=tokenizer,\n                top_logprobs=task.top_logprobs or DEFAULT_TOP_LOGPROBS,\n                selected_token=out.token,\n            )\n\n        if is_done:\n            # Log generation stats\n            generation_elapsed = time.perf_counter() - generation_start_time\n            generated_tokens = len(generated_text_parts)\n            generation_tps = (\n                generated_tokens / generation_elapsed if generation_elapsed > 0 else 0.0\n            )\n            logger.debug(\n                f\"Generation complete: prefill {prompt_tokens} tokens @ \"\n                f\"{prefill_tps:.1f} tok/s, generated {generated_tokens} tokens @ \"\n                f\"{generation_tps:.1f} tok/s\"\n            )\n            if kv_prefix_cache is not None:\n                generated_tokens_array = mx.array(\n                    tokenizer.encode(\n                        \"\".join(generated_text_parts), add_special_tokens=False\n                    )\n                )\n                full_prompt_tokens = mx.concatenate(\n                    [all_prompt_tokens, generated_tokens_array]\n                )\n                hit_ratio = (\n                    prefix_hit_length / len(all_prompt_tokens)\n                    if len(all_prompt_tokens) > 0\n                    else 0.0\n                )\n                if (\n                    matched_index is not None\n                    and hit_ratio >= _MIN_PREFIX_HIT_RATIO_TO_UPDATE\n                ):\n                    kv_prefix_cache.update_kv_cache(\n                        matched_index,\n                        full_prompt_tokens,\n                        caches,\n                        cache_snapshots,\n                        restore_pos=prefix_hit_length,\n                    )\n                else:\n                    kv_prefix_cache.add_kv_cache(\n                        full_prompt_tokens, caches, cache_snapshots\n                    )\n\n        if on_generation_token is not None:\n            on_generation_token()\n\n        yield GenerationResponse(\n            text=text,\n            token=out.token,\n            logprob=logprob,\n            top_logprobs=top_logprobs,\n            finish_reason=finish_reason,\n            stats=stats,\n            usage=usage,\n        )\n\n        if is_done:\n            mx_barrier(group)\n            break\n\n        # Limit accumulated_text to what's needed for stop sequence detection\n        if max_stop_len > 0 and len(accumulated_text) > max_stop_len:\n            accumulated_text = accumulated_text[-max_stop_len:]\n"
  },
  {
    "path": "src/exo/worker/engines/mlx/tests/test_batch_generate.py",
    "content": "# pyright: reportAny=false, reportUnknownVariableType=false\n# pyright: reportUnknownMemberType=false, reportUnknownArgumentType=false\n# pyright: reportUnknownLambdaType=false, reportPrivateUsage=false\n# pyright: reportInvalidCast=false, reportArgumentType=false\n# pyright: reportUnusedImport=false\n\"\"\"Test B=1 vs B=2 equivalence for batch generation.\n\nVerifies that running two requests concurrently in a batch (B=2) produces\nidentical token selections to running them sequentially (B=1).\nUses random weights — no model download required.\n\"\"\"\n\nfrom pathlib import Path\nfrom typing import cast\n\nimport mlx.core as mx\nimport mlx.nn as nn\nimport mlx.utils\nimport pytest\nfrom mlx_lm.generate import _merge_caches\nfrom mlx_lm.sample_utils import make_sampler\nfrom mlx_lm.tokenizer_utils import TokenizerWrapper\nfrom transformers import AutoTokenizer\n\n# Import batch_generate to activate the right-padding BatchKVCache patch\nimport exo.worker.engines.mlx.generator.batch_generate  # noqa: F401\nfrom exo.shared.types.mlx import Model\nfrom exo.worker.engines.mlx.cache import encode_prompt, make_kv_cache\nfrom exo.worker.engines.mlx.generator.generate import prefill\n\nNUM_STEPS = 20\n\n\ndef _init_random(model: nn.Module) -> None:\n    \"\"\"Initialize all model parameters with random values.\"\"\"\n    params = model.parameters()\n    new_params = mlx.utils.tree_map(\n        lambda p: mx.random.normal(shape=p.shape, dtype=p.dtype)\n        if isinstance(p, mx.array)\n        else p,\n        params,\n    )\n    model.update(new_params)\n    mx.eval(model.parameters())\n\n\ndef _run_b1_vs_b2(\n    model: Model,\n    tokenizer: TokenizerWrapper,\n    tokens_a: mx.array,\n    tokens_b: mx.array,\n) -> tuple[float, int]:\n    \"\"\"Run B=1 sequential and B=2 batched, return (max_diff, mismatches).\"\"\"\n    sampler = make_sampler(temp=0.0)\n\n    # B=1 sequential\n    cache_a1 = make_kv_cache(model)\n    prefill(model, tokenizer, sampler, tokens_a[:-1], cache_a1, None, None, None)\n    merged_a1 = _merge_caches([[c for c in cache_a1]])\n    for c in merged_a1:\n        c.prepare(lengths=[1], right_padding=[0])\n    model(mx.array([[tokens_a[-2].item()]]), cache=merged_a1)\n    mx.eval([c.state for c in merged_a1])\n    for c in merged_a1:\n        c.finalize()\n\n    cache_b1 = make_kv_cache(model)\n    prefill(model, tokenizer, sampler, tokens_b[:-1], cache_b1, None, None, None)\n    merged_b1 = _merge_caches([[c for c in cache_b1]])\n    for c in merged_b1:\n        c.prepare(lengths=[1], right_padding=[0])\n    model(mx.array([[tokens_b[-2].item()]]), cache=merged_b1)\n    mx.eval([c.state for c in merged_b1])\n    for c in merged_b1:\n        c.finalize()\n\n    b1_logits_a: list[mx.array] = []\n    b1_logits_b: list[mx.array] = []\n    next_a, next_b = tokens_a[-1].item(), tokens_b[-1].item()\n    for _ in range(NUM_STEPS):\n        la = model(mx.array([[next_a]]), cache=merged_a1)\n        mx.eval(la)\n        b1_logits_a.append(la[0, -1])\n        next_a = int(mx.argmax(la[0, -1]).item())\n        lb = model(mx.array([[next_b]]), cache=merged_b1)\n        mx.eval(lb)\n        b1_logits_b.append(lb[0, -1])\n        next_b = int(mx.argmax(lb[0, -1]).item())\n\n    # B=2 batched\n    cache_a2 = make_kv_cache(model)\n    cache_b2 = make_kv_cache(model)\n    prefill(model, tokenizer, sampler, tokens_a[:-1], cache_a2, None, None, None)\n    prefill(model, tokenizer, sampler, tokens_b[:-1], cache_b2, None, None, None)\n    merged_b2 = _merge_caches([list(cache_a2), 
list(cache_b2)])\n    for c in merged_b2:\n        c.prepare(lengths=[1, 1], right_padding=[0, 0])\n    model(\n        mx.array([[tokens_a[-2].item()], [tokens_b[-2].item()]]),\n        cache=merged_b2,\n    )\n    mx.eval([c.state for c in merged_b2])\n    for c in merged_b2:\n        c.finalize()\n\n    b2_logits_a: list[mx.array] = []\n    b2_logits_b: list[mx.array] = []\n    next_a2, next_b2 = tokens_a[-1].item(), tokens_b[-1].item()\n    for _ in range(NUM_STEPS):\n        l2 = model(mx.array([[next_a2], [next_b2]]), cache=merged_b2)\n        mx.eval(l2)\n        b2_logits_a.append(l2[0, -1])\n        b2_logits_b.append(l2[1, -1])\n        next_a2 = int(mx.argmax(l2[0, -1]).item())\n        next_b2 = int(mx.argmax(l2[1, -1]).item())\n\n    # Compare\n    max_diff = 0.0\n    mismatches = 0\n    for step in range(NUM_STEPS):\n        diff_a = float(\n            mx.max(\n                mx.abs(\n                    b1_logits_a[step].astype(mx.float32)\n                    - b2_logits_a[step].astype(mx.float32)\n                )\n            ).item()\n        )\n        diff_b = float(\n            mx.max(\n                mx.abs(\n                    b1_logits_b[step].astype(mx.float32)\n                    - b2_logits_b[step].astype(mx.float32)\n                )\n            ).item()\n        )\n        max_diff = max(max_diff, diff_a, diff_b)\n        if int(mx.argmax(b1_logits_a[step]).item()) != int(\n            mx.argmax(b2_logits_a[step]).item()\n        ):\n            mismatches += 1\n        if int(mx.argmax(b1_logits_b[step]).item()) != int(\n            mx.argmax(b2_logits_b[step]).item()\n        ):\n            mismatches += 1\n\n    return max_diff, mismatches\n\n\ndef _make_tokenizer() -> TokenizerWrapper:\n    \"\"\"Load the Qwen tokenizer (tiny download, shared across Qwen models).\"\"\"\n    from huggingface_hub import snapshot_download\n\n    model_path = Path(\n        snapshot_download(\n            \"mlx-community/Qwen3.5-35B-A3B-4bit\",\n            allow_patterns=[\"tokenizer*\", \"*.jinja\"],\n        )\n    )\n    hf_tokenizer = AutoTokenizer.from_pretrained(model_path)\n    return TokenizerWrapper(hf_tokenizer)\n\n\n@pytest.mark.slow\ndef test_batch_b2_llama() -> None:\n    \"\"\"Llama-style model (KVCache only) must produce bit-exact logits in B=2.\n\n    Right-padded BatchKVCache keeps data at position 0 for all sequences,\n    so flash attention sees identical data layout as B=1 → bit-exact output.\n    \"\"\"\n    from mlx_lm.models.llama import Model as LlamaModel\n    from mlx_lm.models.llama import ModelArgs\n\n    mx.random.seed(42)\n    args = ModelArgs(\n        model_type=\"llama\",\n        hidden_size=256,\n        num_hidden_layers=4,\n        intermediate_size=512,\n        num_attention_heads=4,\n        num_key_value_heads=2,\n        rms_norm_eps=1e-6,\n        vocab_size=248320,\n        rope_theta=10000.0,\n        tie_word_embeddings=True,\n    )\n    model = LlamaModel(args)\n    _init_random(model)\n\n    tokenizer = _make_tokenizer()\n    tokens_a = encode_prompt(tokenizer, \"Write a short essay about AI.\")\n    tokens_b = encode_prompt(tokenizer, \"Explain evolution briefly.\")\n\n    max_diff, mismatches = _run_b1_vs_b2(\n        cast(Model, model), tokenizer, tokens_a, tokens_b\n    )\n    assert mismatches == 0, f\"Llama B=2 token mismatches: {mismatches}/{NUM_STEPS * 2}\"\n    assert max_diff < 0.002, f\"Llama B=2 max logit diff: {max_diff}\"\n\n\n@pytest.mark.slow\ndef test_batch_b2_qwen35_moe() -> None:\n    \"\"\"Qwen3.5 
MoE model (hybrid SSM+attention+MoE) must produce bit-exact logits in B=2.\n\n    Right-padded BatchKVCache keeps data at position 0 for all sequences,\n    so flash attention sees identical data layout as B=1 → bit-exact output.\n    \"\"\"\n    from mlx_lm.models.qwen3_5_moe import Model as Qwen35MoeModel\n    from mlx_lm.models.qwen3_5_moe import ModelArgs\n\n    mx.random.seed(42)\n    config = {\n        \"model_type\": \"qwen3_5_moe\",\n        \"text_config\": {\n            \"model_type\": \"qwen3_5_moe_text\",\n            \"hidden_size\": 256,\n            \"num_hidden_layers\": 8,\n            \"intermediate_size\": 512,\n            \"num_attention_heads\": 4,\n            \"num_key_value_heads\": 2,\n            \"rms_norm_eps\": 1e-6,\n            \"vocab_size\": 248320,\n            \"head_dim\": 64,\n            \"max_position_embeddings\": 4096,\n            \"full_attention_interval\": 4,\n            \"layer_types\": [\n                \"linear_attention\",\n                \"linear_attention\",\n                \"linear_attention\",\n                \"full_attention\",\n                \"linear_attention\",\n                \"linear_attention\",\n                \"linear_attention\",\n                \"full_attention\",\n            ],\n            \"linear_conv_kernel_dim\": 4,\n            \"linear_key_head_dim\": 64,\n            \"linear_num_key_heads\": 4,\n            \"linear_num_value_heads\": 4,\n            \"linear_value_head_dim\": 64,\n            \"mamba_ssm_dtype\": \"float32\",\n            \"num_experts\": 8,\n            \"num_experts_per_tok\": 2,\n            \"moe_intermediate_size\": 256,\n            \"shared_expert_intermediate_size\": 256,\n            \"rope_parameters\": {\n                \"rope_type\": \"default\",\n                \"rope_theta\": 10000000,\n            },\n            \"attention_bias\": False,\n            \"attn_output_gate\": True,\n        },\n    }\n    args = ModelArgs.from_dict(config)\n    model = Qwen35MoeModel(args)\n    _init_random(model)\n\n    tokenizer = _make_tokenizer()\n    tokens_a = encode_prompt(tokenizer, \"Write a short essay about AI.\")\n    tokens_b = encode_prompt(tokenizer, \"Explain evolution briefly.\")\n\n    max_diff, mismatches = _run_b1_vs_b2(\n        cast(Model, model), tokenizer, tokens_a, tokens_b\n    )\n    assert mismatches == 0, (\n        f\"Qwen3.5 MoE B=2 token mismatches: {mismatches}/{NUM_STEPS * 2}\"\n    )\n    assert max_diff < 0.002, f\"Qwen3.5 MoE B=2 max logit diff: {max_diff}\"\n"
  },
  {
    "path": "src/exo/worker/engines/mlx/utils_mlx.py",
    "content": "import json\nimport os\nimport re\nimport sys\nimport tempfile\nimport time\nfrom pathlib import Path\nfrom typing import Any, cast\n\n# Monkey-patch for transformers 5.x compatibility\n# Kimi's tokenization_kimi.py imports bytes_to_unicode from the old location\n# which was moved in transformers 5.0.0rc2\ntry:\n    import transformers.models.gpt2.tokenization_gpt2 as gpt2_tokenization\n    from transformers.convert_slow_tokenizer import bytes_to_unicode\n\n    if not hasattr(gpt2_tokenization, \"bytes_to_unicode\"):\n        gpt2_tokenization.bytes_to_unicode = bytes_to_unicode  # type: ignore[attr-defined]\nexcept ImportError:\n    pass  # transformers < 5.0 or bytes_to_unicode not available\n\nfrom mlx_lm.models.cache import KVCache\nfrom mlx_lm.models.deepseek_v3 import DeepseekV3Model\nfrom mlx_lm.tokenizer_utils import TokenizerWrapper\n\nfrom exo.shared.models.model_cards import ModelId\nfrom exo.worker.engines.mlx.constants import TRUST_REMOTE_CODE\n\ntry:\n    from mlx_lm.tokenizer_utils import load_tokenizer\nexcept ImportError:\n    from mlx_lm.tokenizer_utils import load as load_tokenizer\nimport contextlib\n\nimport mlx.core as mx\nimport mlx.nn as nn\nfrom mlx_lm.utils import load_model\nfrom pydantic import RootModel\n\nfrom exo.download.download_utils import build_model_path\nfrom exo.shared.types.common import Host\nfrom exo.shared.types.memory import Memory\nfrom exo.shared.types.mlx import Model\nfrom exo.shared.types.tasks import TaskId, TextGeneration\nfrom exo.shared.types.text_generation import TextGenerationTaskParams\nfrom exo.shared.types.worker.instances import (\n    BoundInstance,\n    MlxJacclInstance,\n    MlxRingInstance,\n)\nfrom exo.shared.types.worker.shards import (\n    CfgShardMetadata,\n    PipelineShardMetadata,\n    ShardMetadata,\n    TensorShardMetadata,\n)\nfrom exo.worker.engines.mlx.auto_parallel import (\n    LayerLoadedCallback,\n    TimeoutCallback,\n    eval_with_timeout,\n    get_inner_model,\n    get_layers,\n    pipeline_auto_parallel,\n    tensor_auto_parallel,\n)\nfrom exo.worker.runner.bootstrap import logger\n\nGroup = mx.distributed.Group\n\n\ndef get_weights_size(model_shard_meta: ShardMetadata) -> Memory:\n    return Memory.from_float_kb(\n        (model_shard_meta.end_layer - model_shard_meta.start_layer)\n        / model_shard_meta.n_layers\n        * model_shard_meta.model_card.storage_size.in_kb\n        / (\n            1\n            if isinstance(model_shard_meta, PipelineShardMetadata)\n            else model_shard_meta.world_size\n        )\n    )\n\n\nclass ModelLoadingTimeoutError(Exception):\n    pass\n\n\nclass HostList(RootModel[list[str]]):\n    @classmethod\n    def from_hosts(cls, hosts: list[Host]) -> \"HostList\":\n        return cls(root=[str(host) for host in hosts])\n\n\ndef mlx_distributed_init(\n    bound_instance: BoundInstance,\n) -> Group:\n    \"\"\"\n    Initialize MLX distributed.\n    \"\"\"\n    rank = bound_instance.bound_shard.device_rank\n    logger.info(f\"Starting initialization for rank {rank}\")\n\n    with tempfile.TemporaryDirectory() as tmpdir:\n        coordination_file = str(\n            Path(tmpdir) / f\"hosts_{bound_instance.instance.instance_id}_{rank}.json\"\n        )\n        # TODO: singleton instances\n        match bound_instance.instance:\n            case MlxRingInstance(hosts_by_node=hosts_by_node, ephemeral_port=_):\n                hosts_for_node = hosts_by_node[bound_instance.bound_node_id]\n                hosts_json = 
HostList.from_hosts(hosts_for_node).model_dump_json()\n\n                with open(coordination_file, \"w\") as f:\n                    _ = f.write(hosts_json)\n\n                logger.info(\n                    f\"rank {rank} hostfile: {coordination_file} hosts: {hosts_json}\"\n                )\n\n                os.environ[\"MLX_HOSTFILE\"] = coordination_file\n                os.environ[\"MLX_RANK\"] = str(rank)\n                os.environ[\"MLX_RING_VERBOSE\"] = \"1\"\n                group = mx.distributed.init(backend=\"ring\", strict=True)\n\n            case MlxJacclInstance(\n                jaccl_devices=jaccl_devices, jaccl_coordinators=jaccl_coordinators\n            ):\n                assert all(\n                    jaccl_devices[i][i] is None for i in range(len(jaccl_devices))\n                )\n                # Use RDMA connectivity matrix\n                jaccl_devices_json = json.dumps(jaccl_devices)\n\n                with open(coordination_file, \"w\") as f:\n                    _ = f.write(jaccl_devices_json)\n\n                jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]\n\n                logger.info(\n                    f\"rank {rank} MLX_IBV_DEVICES: {coordination_file} with devices: {jaccl_devices_json}\"\n                )\n                logger.info(f\"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}\")\n                os.environ[\"MLX_IBV_DEVICES\"] = coordination_file\n                os.environ[\"MLX_RANK\"] = str(rank)\n                os.environ[\"MLX_JACCL_COORDINATOR\"] = jaccl_coordinator\n                group = mx.distributed.init(backend=\"jaccl\", strict=True)\n\n        logger.info(f\"Rank {rank} mlx distributed initialization complete\")\n\n        return group\n\n\ndef initialize_mlx(\n    bound_instance: BoundInstance,\n) -> Group:\n    # should we unseed it?\n    # TODO: pass in seed from params\n    mx.random.seed(42)\n\n    assert len(bound_instance.instance.shard_assignments.node_to_runner) > 1, (\n        \"Tried to initialize mlx for a single node instance\"\n    )\n    return mlx_distributed_init(bound_instance)\n\n\ndef load_mlx_items(\n    bound_instance: BoundInstance,\n    group: Group | None,\n    on_timeout: TimeoutCallback | None,\n    on_layer_loaded: LayerLoadedCallback | None,\n) -> tuple[Model, TokenizerWrapper]:\n    if group is None:\n        logger.info(f\"Single device used for {bound_instance.instance}\")\n        model_path = build_model_path(bound_instance.bound_shard.model_card.model_id)\n        start_time = time.perf_counter()\n        model, _ = load_model(model_path, lazy=True, strict=False)\n        # Eval layers one by one for progress reporting\n        try:\n            inner = get_inner_model(model)\n            layers = get_layers(inner)\n            total = len(layers)\n            for i, layer in enumerate(layers):\n                mx.eval(layer)  # type: ignore\n                if on_layer_loaded is not None:\n                    on_layer_loaded(i, total)\n        except ValueError as e:\n            logger.opt(exception=e).debug(\n                \"Model architecture doesn't support layer-by-layer progress tracking\",\n            )\n        mx.eval(model)\n        end_time = time.perf_counter()\n        logger.info(f\"Time taken to load model: {(end_time - start_time):.2f}s\")\n        tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)\n\n    else:\n        logger.info(\"Starting distributed init\")\n        start_time = time.perf_counter()\n        
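# Multi-node path: each rank loads and shards its slice, then synchronizes with the group.\n        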
model, tokenizer = shard_and_load(\n            bound_instance.bound_shard,\n            group=group,\n            on_timeout=on_timeout,\n            on_layer_loaded=on_layer_loaded,\n        )\n        end_time = time.perf_counter()\n        logger.info(\n            f\"Time taken to shard and load model: {(end_time - start_time):.2f}s\"\n        )\n\n    set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))\n\n    mx.clear_cache()\n\n    return cast(Model, model), tokenizer\n\n\ndef shard_and_load(\n    shard_metadata: ShardMetadata,\n    group: Group,\n    on_timeout: TimeoutCallback | None,\n    on_layer_loaded: LayerLoadedCallback | None,\n) -> tuple[nn.Module, TokenizerWrapper]:\n    model_path = build_model_path(shard_metadata.model_card.model_id)\n\n    model, _ = load_model(model_path, lazy=True, strict=False)\n    logger.debug(model)\n    if hasattr(model, \"model\") and isinstance(model.model, DeepseekV3Model):  # type: ignore\n        pass\n        # TODO: See if we should quantize the model.\n        # def is_attention_layer(path: str) -> bool:\n        #     path = path.lower()\n\n        #     return \"self_attn\" in path and \"layernorm\" not in path\n\n        # def quant_predicate(path: str, module: nn.Module):\n        #     if not isinstance(module, nn.Linear):\n        #         return False\n\n        #     return is_attention_layer(path)\n        # model, config = quantize_model(\n        #        model, config, group_size=KV_GROUP_SIZE, bits=ATTENTION_KV_BITS, quant_predicate=quant_predicate, mode=QUANTIZE_MODEL_MODE\n        #    )\n\n    assert isinstance(model, nn.Module)\n\n    tokenizer = get_tokenizer(model_path, shard_metadata)\n\n    logger.info(f\"Group size: {group.size()}, group rank: {group.rank()}\")\n\n    # Scale the eval timeout with model size: the base timeout plus one extra second\n    # per GB of weights (e.g. a 100GB model with the default 300s base gets 400s).\n    base_timeout = float(os.environ.get(\"EXO_MODEL_LOAD_TIMEOUT\", \"300\"))\n    model_size = get_weights_size(shard_metadata)\n    timeout_seconds = base_timeout + model_size.in_gb\n    logger.info(\n        f\"Evaluating model parameters with timeout of {timeout_seconds:.0f}s \"\n        f\"(model size: {model_size.in_gb:.1f}GB)\"\n    )\n\n    match shard_metadata:\n        case TensorShardMetadata():\n            logger.info(f\"loading model from {model_path} with tensor parallelism\")\n            model = tensor_auto_parallel(\n                model, group, timeout_seconds, on_timeout, on_layer_loaded\n            )\n        case PipelineShardMetadata():\n            logger.info(f\"loading model from {model_path} with pipeline parallelism\")\n            model = pipeline_auto_parallel(\n                model, group, shard_metadata, on_layer_loaded=on_layer_loaded\n            )\n            eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)\n        case CfgShardMetadata():\n            raise ValueError(\n                \"CfgShardMetadata is not supported for text model loading - \"\n                \"this metadata type is only for image generation models\"\n            )\n\n    # TODO: Do we need this?\n    mx.eval(model)\n\n    logger.debug(\"SHARDED\")\n    logger.debug(model)\n\n    # Synchronize processes before generation to avoid timeout\n    mx_barrier(group)\n\n    return model, tokenizer\n\n\ndef get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerWrapper:\n    \"\"\"Load tokenizer for a model shard. 
Delegates to load_tokenizer_for_model_id.\"\"\"\n    return load_tokenizer_for_model_id(\n        shard_metadata.model_card.model_id,\n        model_path,\n        trust_remote_code=shard_metadata.model_card.trust_remote_code,\n    )\n\n\ndef get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:\n    \"\"\"\n    Get the EOS token IDs for a model based on its ID.\n\n    Some models require explicit EOS token configuration that isn't in their\n    tokenizer config. This function returns the known EOS token IDs for such models.\n\n    Args:\n        model_id: The HuggingFace model ID\n\n    Returns:\n        List of EOS token IDs, or None if the model uses standard tokenizer config\n    \"\"\"\n    model_id_lower = model_id.lower()\n    if \"kimi-k2\" in model_id_lower:\n        return [163586]\n    elif \"glm-5\" in model_id_lower or \"glm-4.7\" in model_id_lower:\n        # For GLM-5 and GLM-4.7\n        # 154820: <|endoftext|>, 154827: <|user|>, 154829: <|observation|>\n        return [154820, 154827, 154829]\n    elif \"glm\" in model_id_lower:\n        # For GLM-4.5 and older\n        return [151336, 151329, 151338]\n    elif \"gpt-oss\" in model_id_lower:\n        return [200002, 200012]\n    elif \"qwen3.5\" in model_id_lower or \"qwen-3.5\" in model_id_lower:\n        # For Qwen3.5: 248046 (<|im_end|>), 248044 (<|endoftext|>)\n        return [248046, 248044]\n    return None\n\n\ndef load_tokenizer_for_model_id(\n    model_id: ModelId, model_path: Path, *, trust_remote_code: bool = TRUST_REMOTE_CODE\n) -> TokenizerWrapper:\n    \"\"\"\n    Load tokenizer for a model given its ID and local path.\n\n    This is the core tokenizer loading logic, handling special cases for different\n    model families (Kimi, GLM, etc.) and transformers 5.x compatibility.\n\n    Args:\n        model_id: The HuggingFace model ID (e.g., \"moonshotai/Kimi-K2-Instruct\")\n        model_path: Local path where the model/tokenizer files are stored\n\n    Returns:\n        TokenizerWrapper instance configured for the model\n    \"\"\"\n    model_id_lower = model_id.lower()\n    eos_token_ids = get_eos_token_ids_for_model(model_id)\n\n    # Kimi uses a custom TikTokenTokenizer that transformers 5.x can't load via AutoTokenizer\n    if \"kimi-k2\" in model_id_lower:\n        import importlib.util\n        import types\n\n        sys.path.insert(0, str(model_path))\n\n        # Load tool_declaration_ts first (tokenization_kimi imports it with relative import)\n        tool_decl_path = model_path / \"tool_declaration_ts.py\"\n        if tool_decl_path.exists():\n            spec = importlib.util.spec_from_file_location(\n                \"tool_declaration_ts\", tool_decl_path\n            )\n            if spec and spec.loader:\n                tool_decl_module = importlib.util.module_from_spec(spec)\n                sys.modules[\"tool_declaration_ts\"] = tool_decl_module\n                spec.loader.exec_module(tool_decl_module)\n\n        # Load tokenization_kimi with patched source (convert relative to absolute import)\n        tok_path = model_path / \"tokenization_kimi.py\"\n        source = tok_path.read_text()\n        source = source.replace(\"from .tool_declaration_ts\", \"from tool_declaration_ts\")\n        spec = importlib.util.spec_from_file_location(\"tokenization_kimi\", tok_path)\n        if spec:\n            tok_module = types.ModuleType(\"tokenization_kimi\")\n            tok_module.__file__ = str(tok_path)\n            sys.modules[\"tokenization_kimi\"] = tok_module\n            
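# Register the module before exec so the patched source resolves imports against\n            # this module object, mirroring importlib's loader protocol.\n            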
exec(compile(source, tok_path, \"exec\"), tok_module.__dict__)  # noqa: S102\n            TikTokenTokenizer = tok_module.TikTokenTokenizer  # type: ignore[attr-defined]  # noqa: N806\n        else:\n            from tokenization_kimi import TikTokenTokenizer  # type: ignore[import-not-found]  # noqa: I001\n\n        hf_tokenizer: Any = TikTokenTokenizer.from_pretrained(model_path)  # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]\n\n        # Patch encode to use internal tiktoken model directly\n        # transformers 5.x has a bug in the encode->pad path for slow tokenizers\n        def _patched_encode(text: str, **_kwargs: object) -> list[int]:\n            # Pass allowed_special=\"all\" to handle special tokens like <|im_user|>\n            return list(hf_tokenizer.model.encode(text, allowed_special=\"all\"))  # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]\n\n        hf_tokenizer.encode = _patched_encode\n        return TokenizerWrapper(\n            hf_tokenizer,\n            eos_token_ids=eos_token_ids,\n            tool_call_start=\"<|tool_calls_section_begin|>\",\n            tool_call_end=\"<|tool_calls_section_end|>\",\n            tool_parser=_parse_kimi_tool_calls,\n        )\n\n    tokenizer = load_tokenizer(\n        model_path,\n        tokenizer_config_extra={\"trust_remote_code\": trust_remote_code},\n        eos_token_ids=eos_token_ids,\n    )\n\n    if \"gemma-3\" in model_id_lower:\n        gemma_3_eos_id = 1\n        gemma_3_end_of_turn_id = 106\n        if tokenizer.eos_token_ids is not None:\n            if gemma_3_end_of_turn_id not in tokenizer.eos_token_ids:\n                tokenizer.eos_token_ids = list(tokenizer.eos_token_ids) + [\n                    gemma_3_end_of_turn_id\n                ]\n        else:\n            tokenizer.eos_token_ids = [gemma_3_eos_id, gemma_3_end_of_turn_id]\n\n    return tokenizer\n\n\ndef _normalize_tool_calls(msg_dict: dict[str, Any]) -> None:\n    \"\"\"Normalize tool_calls in a message dict.\n\n    OpenAI format has tool_calls[].function.arguments as a JSON string,\n    but some chat templates (e.g., GLM) expect it as a dict.\n    \"\"\"\n    tool_calls = msg_dict.get(\"tool_calls\")\n    if not tool_calls or not isinstance(tool_calls, list):\n        return\n\n    for tc in tool_calls:  # pyright: ignore[reportUnknownVariableType]\n        if not isinstance(tc, dict):\n            continue\n        func = tc.get(\"function\")  # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]\n        if not isinstance(func, dict):\n            continue\n        args = func.get(\"arguments\")  # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]\n        if isinstance(args, str):\n            with contextlib.suppress(json.JSONDecodeError):\n                func[\"arguments\"] = json.loads(args)\n\n\ndef _collect_nested_property_names(schema: dict[str, Any]) -> set[str]:\n    names: set[str] = set()\n    properties: dict[str, Any] = schema.get(\"properties\", {})  # type: ignore[reportAny]\n    for prop_spec in properties.values():  # pyright: ignore[reportAny]\n        if not isinstance(prop_spec, dict):\n            continue\n        if prop_spec.get(\"type\") == \"array\":  # type: ignore[reportAny]\n            items: dict[str, Any] | None = prop_spec.get(\"items\")  # type: ignore[reportAny]\n            if isinstance(items, dict) and items.get(\"type\") == \"object\":  # type: ignore[reportAny]\n                inner_props: dict[str, Any] = 
items.get(\"properties\", {})  # type: ignore[reportAny]\n                for k in inner_props:  # pyright: ignore[reportUnknownVariableType]\n                    names.add(str(k))  # pyright: ignore[reportUnknownArgumentType]\n                names.update(_collect_nested_property_names(items))  # pyright: ignore[reportUnknownArgumentType]\n    return names\n\n\ndef _schemas_lost_in_prompt(prompt: str, tools: list[dict[str, Any]]) -> bool:\n    \"\"\"Return True if nested property names from any tool schema are absent.\"\"\"\n    for tool in tools:\n        fn: dict[str, Any] = tool.get(\"function\", {})  # type: ignore\n        params: dict[str, Any] = fn.get(\"parameters\", {})  # type: ignore\n        nested = _collect_nested_property_names(params)\n        if nested and not all(name in prompt for name in nested):\n            return True\n    return False\n\n\n_LOSSY_TEMPLATE_PATTERN = re.compile(\n    r\"\"\"inner_type\\s*==\\s*[\"']object \\| object[\"']\\s*or\\s*inner_type\\|length\\s*>\\s*\\d+\"\"\",\n)\n\n\ndef _patch_lossy_chat_template(template: str) -> str | None:\n    \"\"\"Patch chat templates that collapse nested object schemas to ``any[]``.\n\n    Some templates (e.g., GPT-OSS) have a guard like::\n\n        inner_type == \"object | object\" or inner_type|length > 50\n\n    The length check silently drops complex array-of-object schemas.\n    We remove the length guard, keeping only the object-union check.\n    Returns the patched template, or *None* if no patch was needed.\n    \"\"\"\n    patched, n = _LOSSY_TEMPLATE_PATTERN.subn(\n        lambda m: m.group(0).split(\" or \")[0],  # keep only the object-union check\n        template,\n    )\n    return patched if n > 0 else None\n\n\ndef _needs_dsml_encoding(task_params: TextGenerationTaskParams) -> bool:\n    if \"deepseek-v3.2\" not in task_params.model.lower():\n        return False\n    # Use DSML encoding when tools are provided or tool results are in the conversation\n    if task_params.tools:\n        return True\n    if task_params.chat_template_messages:\n        return any(\n            msg.get(\"role\") == \"tool\" for msg in task_params.chat_template_messages\n        )\n    return False\n\n\ndef apply_chat_template(\n    tokenizer: TokenizerWrapper,\n    task_params: TextGenerationTaskParams,\n) -> str:\n    \"\"\"Convert TextGenerationTaskParams to a chat template prompt.\n\n    Converts the internal format (input + instructions) to a messages list\n    that can be processed by the tokenizer's chat template.\n\n    When chat_template_messages is available (from Chat Completions API),\n    uses those directly to preserve tool_calls, thinking, and other fields.\n    \"\"\"\n    formatted_messages: list[dict[str, Any]] = []\n    if task_params.chat_template_messages is not None:\n        # Use pre-formatted messages that preserve tool_calls, thinking, etc.\n        formatted_messages = list(task_params.chat_template_messages)\n        for msg in formatted_messages:\n            _normalize_tool_calls(msg)\n    else:\n        # Add system message (instructions) if present\n        if task_params.instructions:\n            formatted_messages.append(\n                {\"role\": \"system\", \"content\": task_params.instructions}\n            )\n\n        # Convert input to messages\n        for msg in task_params.input:\n            if not msg.content:\n                logger.warning(\"Received message with empty content, skipping\")\n                continue\n            formatted_messages.append({\"role\": 
msg.role, \"content\": msg.content})\n\n    # For assistant prefilling, append content after templating to avoid a closing turn token.\n    partial_assistant_content: str | None = None\n    if formatted_messages and formatted_messages[-1].get(\"role\") == \"assistant\":\n        partial_assistant_content = cast(str, formatted_messages[-1].get(\"content\", \"\"))\n        formatted_messages = formatted_messages[:-1]\n\n    if _needs_dsml_encoding(task_params):\n        from exo.worker.engines.mlx.dsml_encoding import encode_messages\n\n        prompt = encode_messages(\n            messages=formatted_messages,\n            thinking_mode=\"thinking\" if task_params.enable_thinking else \"chat\",\n            tools=task_params.tools,\n        )\n        if partial_assistant_content:\n            prompt += partial_assistant_content\n        logger.info(prompt)\n        return prompt\n\n    extra_kwargs: dict[str, Any] = {}\n    if task_params.enable_thinking is not None:\n        # Qwen3 and GLM use \"enable_thinking\"; DeepSeek uses \"thinking\".\n        # Jinja ignores unknown variables, so passing both is safe.\n        extra_kwargs[\"enable_thinking\"] = task_params.enable_thinking\n        extra_kwargs[\"thinking\"] = task_params.enable_thinking\n    if task_params.reasoning_effort is not None:\n        extra_kwargs[\"reasoning_effort\"] = task_params.reasoning_effort\n\n    patched_template: str | None = None\n    if task_params.tools:\n        original_template: str | None = getattr(tokenizer, \"chat_template\", None)\n        if isinstance(original_template, str):\n            patched_template = _patch_lossy_chat_template(original_template)\n            if patched_template is not None:\n                logger.info(\n                    \"Patched lossy chat template (removed inner_type length guard)\"\n                )\n\n    prompt: str = tokenizer.apply_chat_template(\n        formatted_messages,\n        tokenize=False,\n        add_generation_prompt=True,\n        tools=task_params.tools,\n        **({\"chat_template\": patched_template} if patched_template is not None else {}),\n        **extra_kwargs,\n    )\n\n    if task_params.tools and _schemas_lost_in_prompt(prompt, task_params.tools):\n        logger.warning(\"Chat template lost nested tool schemas even after patching\")\n\n    if partial_assistant_content:\n        prompt += partial_assistant_content\n\n    logger.info(prompt)\n\n    return prompt\n\n\ndef detect_thinking_prompt_suffix(prompt: str, tokenizer: TokenizerWrapper) -> bool:\n    \"\"\"\n    Detect if prompt ends with a thinking opening tag that should be\n    prepended to the output stream.\n    \"\"\"\n    think_token = tokenizer.think_start\n\n    return think_token is not None and prompt.rstrip().endswith(think_token)\n\n\ndef fix_unmatched_think_end_tokens(\n    tokens: mx.array, tokenizer: TokenizerWrapper\n) -> mx.array:\n    if not tokenizer.has_thinking:\n        return tokens\n    assert tokenizer.think_start_id\n    assert tokenizer.think_end_id\n    think_start_id: int = tokenizer.think_start_id\n    think_end_id: int = tokenizer.think_end_id\n    token_list: list[int] = cast(list[int], tokens.tolist())\n    result: list[int] = []\n    depth = 0\n    for token in token_list:\n        if token == think_start_id:\n            depth += 1\n        elif token == think_end_id:\n            if depth == 0:\n                result.append(think_start_id)\n            else:\n                depth -= 1\n        result.append(token)\n    return 
mx.array(result)\n\n\nclass NullKVCache(KVCache):\n    \"\"\"\n    A KVCache that pretends to exist but holds zero tokens.\n    It satisfies .state/.meta_state and never allocates real keys/values.\n    \"\"\"\n\n    def __init__(self, dtype: mx.Dtype = mx.float16):\n        super().__init__()\n        # zero-length K/V so shapes/dtypes are defined but empty\n        self.keys = mx.zeros((1, 1, 0, 1), dtype=dtype)\n        self.values = mx.zeros((1, 1, 0, 1), dtype=dtype)\n        self.offset = 0\n\n    @property\n    def state(self) -> tuple[mx.array, mx.array]:\n        # matches what mx.save_safetensors / mx.eval expect\n        return self.keys, self.values\n\n    @state.setter\n    def state(self, v: tuple[mx.array, mx.array]) -> None:\n        raise NotImplementedError(\"We should not be setting a NullKVCache.\")\n\n\ndef mlx_force_oom(size: int = 200000) -> None:\n    \"\"\"\n    Force an Out-Of-Memory (OOM) error in MLX by performing large tensor operations.\n    \"\"\"\n    mx.set_default_device(mx.gpu)\n    a = mx.random.uniform(shape=(size, size), dtype=mx.float32)\n    b = mx.random.uniform(shape=(size, size), dtype=mx.float32)\n    mx.eval(a, b)\n    c = mx.matmul(a, b)\n    d = mx.matmul(a, c)\n    e = mx.matmul(b, c)\n    f = mx.sigmoid(d + e)\n    mx.eval(f)\n\n\ndef set_wired_limit_for_model(model_size: Memory) -> None:\n    \"\"\"\n    Set the wired limit to the device's maximum recommended working set size.\n\n    Warns when the model size approaches that maximum. Note that the wired\n    limit should not be changed during an async eval; if one could be running,\n    synchronize with its streams before calling this function.\n    \"\"\"\n    if not mx.metal.is_available():\n        return\n\n    max_rec_size = Memory.from_bytes(\n        int(mx.device_info()[\"max_recommended_working_set_size\"])\n    )\n    if model_size > 0.9 * max_rec_size:\n        logger.warning(\n            f\"Generating with a model that requires {model_size.in_float_mb:.1f} MB \"\n            f\"which is close to the maximum recommended size of {max_rec_size.in_float_mb:.1f} \"\n            \"MB. This can be slow. 
See the documentation for possible work-arounds: \"\n            \"https://github.com/ml-explore/mlx-lm/tree/main#large-models\"\n        )\n    mx.set_wired_limit(max_rec_size.in_bytes)\n    logger.info(f\"Wired limit set to {max_rec_size}.\")\n\n\ndef mlx_cleanup(\n    model: Model | None, tokenizer: TokenizerWrapper | None, group: Group | None\n) -> None:\n    del model, tokenizer, group\n    mx.clear_cache()\n    import gc\n\n    gc.collect()\n\n\ndef mx_any(bool_: bool, group: Group | None) -> bool:\n    if group is None:\n        return bool_\n    num_true = mx.distributed.all_sum(\n        mx.array(bool_), group=group, stream=mx.default_stream(mx.Device(mx.cpu))\n    )\n    mx.eval(num_true)\n    return num_true.item() > 0\n\n\ndef mx_barrier(group: Group | None):\n    if group is None:\n        return\n    mx.eval(\n        mx.distributed.all_sum(\n            mx.array(1.0), group=group, stream=mx.default_stream(mx.Device(mx.cpu))\n        )\n    )\n\n\ndef _parse_kimi_tool_calls(text: str):\n    import regex as re\n\n    # kimi has a fixed function naming scheme, with a json formatted arg\n    #   functions.multiply:0<|tool_call_argument_begin|>{\"a\": 2, \"b\": 3}\n    _func_name_regex = re.compile(\n        r\"^\\s*((?:functions\\.)?(.+?):\\d+)\\s*<\\|tool_call_argument_begin\\|>\", re.DOTALL\n    )\n    _func_arg_regex = re.compile(r\"<\\|tool_call_argument_begin\\|>\\s*(.*)\\s*\", re.DOTALL)\n    _tool_call_split_regex = re.compile(\n        r\"<\\|tool_call_begin\\|>(.*?)<\\|tool_call_end\\|>\", re.DOTALL\n    )\n\n    def _parse_single_tool(text: str) -> dict[str, Any]:\n        func_name_match = _func_name_regex.search(text)\n        if func_name_match is None:\n            raise ValueError(\"No tool call found.\")\n        tool_call_id = func_name_match.group(1)  # e.g. \"functions.get_weather:0\"\n        func_name = func_name_match.group(2)  # e.g. 
\"get_weather\"\n\n        func_args_match = _func_arg_regex.search(text)\n        if func_args_match is None:\n            raise ValueError(\"No tool call arguments found.\")\n        func_args = func_args_match.group(1)\n        try:\n            arg_dct = json.loads(func_args)  # pyright: ignore[reportAny]\n        except Exception:\n            arg_dct = None\n\n        return dict(id=tool_call_id, name=func_name, arguments=arg_dct)\n\n    tool_matches = _tool_call_split_regex.findall(text)\n    if tool_matches:\n        return [_parse_single_tool(match) for match in tool_matches]  # pyright: ignore[reportAny]\n    else:\n        return [_parse_single_tool(text)]\n\n\ndef mx_all_gather_tasks(\n    tasks: list[TextGeneration],\n    group: mx.distributed.Group | None,\n) -> tuple[list[TextGeneration], list[TextGeneration]]:\n    def encode_task_id(task_id: TaskId) -> list[int]:\n        utf8_task_id = task_id.encode()\n        return [\n            int.from_bytes(utf8_task_id[i : i + 1]) for i in range(len(utf8_task_id))\n        ]\n\n    def decode_task_id(encoded_task_id: list[int]) -> TaskId:\n        return TaskId(\n            bytes.decode(b\"\".join((x).to_bytes(length=1) for x in encoded_task_id))\n        )\n\n    uuid_byte_length = 36\n\n    n_tasks = len(tasks)\n    all_counts = cast(\n        list[int],\n        mx.distributed.all_gather(mx.array([n_tasks]), group=group).tolist(),\n    )\n    max_tasks = max(all_counts)\n    world_size: int = 1 if group is None else group.size()\n\n    if max_tasks == 0:\n        return [], []\n\n    padded = [encode_task_id(task.task_id) for task in tasks] + [\n        [0] * uuid_byte_length\n    ] * (max_tasks - n_tasks)\n\n    assert all(len(encoded_task_id) == uuid_byte_length for encoded_task_id in padded)\n\n    gathered = cast(\n        list[list[list[int]]],\n        mx.distributed.all_gather(mx.array(padded), group=group)\n        .reshape(world_size, max_tasks, -1)\n        .tolist(),\n    )\n    all_task_ids: list[list[TaskId]] = [\n        [decode_task_id(encoded_task_id) for encoded_task_id in rank_tasks[:count]]\n        for rank_tasks, count in zip(gathered, all_counts, strict=True)\n    ]\n\n    agreed_ids = set[TaskId].intersection(*(set(tids) for tids in all_task_ids))\n\n    local_tasks = {task.task_id: task for task in tasks}\n    agreed = [local_tasks[tid] for tid in sorted(agreed_ids)]\n    different = [task for task in tasks if task.task_id not in agreed_ids]\n    return agreed, different\n"
  },
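  {
    "path": "src/exo/worker/engines/mlx/utils_mlx_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the original tree.\n\nRestates the balancing rule from fix_unmatched_think_end_tokens in\nexo.worker.engines.mlx.utils_mlx on plain ints so it can be checked without\nmlx: a think-end token seen at depth zero gets a synthetic think-start\ninserted in front of it. START and END are hypothetical stand-ins for\ntokenizer.think_start_id / think_end_id.\n\"\"\"\n\nSTART, END = 1, 2  # hypothetical token ids\n\n\ndef balance(tokens: list[int]) -> list[int]:\n    result: list[int] = []\n    depth = 0\n    for token in tokens:\n        if token == START:\n            depth += 1\n        elif token == END:\n            if depth == 0:\n                # unmatched closer: synthesize an opener so pairs stay balanced\n                result.append(START)\n            else:\n                depth -= 1\n        result.append(token)\n    return result\n\n\nif __name__ == \"__main__\":\n    assert balance([END, 9]) == [START, END, 9]\n    assert balance([START, 9, END]) == [START, 9, END]\n"
  },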
  {
    "path": "src/exo/worker/main.py",
    "content": "from collections import defaultdict\nfrom datetime import datetime, timezone\n\nimport anyio\nfrom anyio import fail_after\nfrom loguru import logger\n\nfrom exo.api.types import ImageEditsTaskParams\nfrom exo.download.download_utils import resolve_model_in_path\nfrom exo.shared.apply import apply\nfrom exo.shared.models.model_cards import ModelId\nfrom exo.shared.types.commands import (\n    ForwarderCommand,\n    ForwarderDownloadCommand,\n    StartDownload,\n)\nfrom exo.shared.types.common import CommandId, NodeId, SystemId\nfrom exo.shared.types.events import (\n    Event,\n    IndexedEvent,\n    InputChunkReceived,\n    NodeDownloadProgress,\n    NodeGatheredInfo,\n    TaskCreated,\n    TaskStatusUpdated,\n    TopologyEdgeCreated,\n    TopologyEdgeDeleted,\n)\nfrom exo.shared.types.multiaddr import Multiaddr\nfrom exo.shared.types.state import State\nfrom exo.shared.types.tasks import (\n    CancelTask,\n    CreateRunner,\n    DownloadModel,\n    ImageEdits,\n    Shutdown,\n    Task,\n    TaskStatus,\n)\nfrom exo.shared.types.topology import Connection, SocketConnection\nfrom exo.shared.types.worker.downloads import DownloadCompleted\nfrom exo.shared.types.worker.runners import RunnerId\nfrom exo.utils.channels import Receiver, Sender, channel\nfrom exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer\nfrom exo.utils.info_gatherer.net_profile import check_reachable\nfrom exo.utils.keyed_backoff import KeyedBackoff\nfrom exo.utils.task_group import TaskGroup\nfrom exo.worker.plan import plan\nfrom exo.worker.runner.runner_supervisor import RunnerSupervisor\n\n\nclass Worker:\n    def __init__(\n        self,\n        node_id: NodeId,\n        *,\n        event_receiver: Receiver[IndexedEvent],\n        event_sender: Sender[Event],\n        # This is for requesting updates. 
It doesn't need to be a general command sender right now,\n        # but I think it's the correct way to be thinking about commands\n        command_sender: Sender[ForwarderCommand],\n        download_command_sender: Sender[ForwarderDownloadCommand],\n    ):\n        self.node_id: NodeId = node_id\n        self.event_receiver = event_receiver\n        self.event_sender = event_sender\n        self.command_sender = command_sender\n        self.download_command_sender = download_command_sender\n\n        self.state: State = State()\n        self.runners: dict[RunnerId, RunnerSupervisor] = {}\n        self._tg: TaskGroup = TaskGroup()\n\n        self._system_id = SystemId()\n\n        # Buffer for input image chunks (for image editing)\n        self.input_chunk_buffer: dict[CommandId, dict[int, str]] = {}\n        self.input_chunk_counts: dict[CommandId, int] = {}\n\n        self._download_backoff: KeyedBackoff[ModelId] = KeyedBackoff(base=0.5, cap=10.0)\n\n    async def run(self):\n        logger.info(\"Starting Worker\")\n\n        info_send, info_recv = channel[GatheredInfo]()\n        info_gatherer: InfoGatherer = InfoGatherer(info_send)\n\n        try:\n            async with self._tg as tg:\n                tg.start_soon(info_gatherer.run)\n                tg.start_soon(self._forward_info, info_recv)\n                tg.start_soon(self.plan_step)\n                tg.start_soon(self._event_applier)\n                tg.start_soon(self._poll_connection_updates)\n        finally:\n            # Actual shutdown code - waits for all tasks to complete before executing.\n            logger.info(\"Stopping Worker\")\n            self.event_sender.close()\n            self.command_sender.close()\n            self.download_command_sender.close()\n            for runner in self.runners.values():\n                runner.shutdown()\n\n    async def _forward_info(self, recv: Receiver[GatheredInfo]):\n        with recv as info_stream:\n            async for info in info_stream:\n                await self.event_sender.send(\n                    NodeGatheredInfo(\n                        node_id=self.node_id,\n                        when=str(datetime.now(tz=timezone.utc)),\n                        info=info,\n                    )\n                )\n\n    async def _event_applier(self):\n        with self.event_receiver as events:\n            async for event in events:\n                # 2. 
for each event, apply it to the state\n                self.state = apply(self.state, event=event)\n                event = event.event\n\n                # Buffer input image chunks for image editing\n                if isinstance(event, InputChunkReceived):\n                    cmd_id = event.command_id\n                    if cmd_id not in self.input_chunk_buffer:\n                        self.input_chunk_buffer[cmd_id] = {}\n                        self.input_chunk_counts[cmd_id] = event.chunk.total_chunks\n\n                    self.input_chunk_buffer[cmd_id][event.chunk.chunk_index] = (\n                        event.chunk.data\n                    )\n\n    async def plan_step(self):\n        while True:\n            await anyio.sleep(0.1)\n            task: Task | None = plan(\n                self.node_id,\n                self.runners,\n                self.state.downloads,\n                self.state.instances,\n                self.state.runners,\n                self.state.tasks,\n                self.input_chunk_buffer,\n                self.input_chunk_counts,\n            )\n            if task is None:\n                continue\n\n            # Gate DownloadModel on backoff BEFORE emitting TaskCreated\n            # to prevent flooding the event log with useless events\n            if isinstance(task, DownloadModel):\n                model_id = task.shard_metadata.model_card.model_id\n                if not self._download_backoff.should_proceed(model_id):\n                    continue\n\n            logger.info(f\"Worker plan: {task.__class__.__name__}\")\n            assert task.task_status\n            await self.event_sender.send(TaskCreated(task_id=task.task_id, task=task))\n\n            # let's not kill the worker if a runner is unresponsive\n            match task:\n                case CreateRunner():\n                    self._create_supervisor(task)\n                    await self.event_sender.send(\n                        TaskStatusUpdated(\n                            task_id=task.task_id, task_status=TaskStatus.Complete\n                        )\n                    )\n                case DownloadModel(shard_metadata=shard):\n                    model_id = shard.model_card.model_id\n                    self._download_backoff.record_attempt(model_id)\n\n                    found_path = resolve_model_in_path(model_id)\n                    if found_path is not None:\n                        logger.info(\n                            f\"Model {model_id} found in EXO_MODELS_PATH at {found_path}\"\n                        )\n                        await self.event_sender.send(\n                            NodeDownloadProgress(\n                                download_progress=DownloadCompleted(\n                                    node_id=self.node_id,\n                                    shard_metadata=shard,\n                                    model_directory=str(found_path),\n                                    total=shard.model_card.storage_size,\n                                    read_only=True,\n                                )\n                            )\n                        )\n                        await self.event_sender.send(\n                            TaskStatusUpdated(\n                                task_id=task.task_id,\n                                task_status=TaskStatus.Complete,\n                            )\n                        )\n                    else:\n                        await self.download_command_sender.send(\n
                           ForwarderDownloadCommand(\n                                origin=self._system_id,\n                                command=StartDownload(\n                                    target_node_id=self.node_id,\n                                    shard_metadata=shard,\n                                ),\n                            )\n                        )\n                        await self.event_sender.send(\n                            TaskStatusUpdated(\n                                task_id=task.task_id,\n                                task_status=TaskStatus.Running,\n                            )\n                        )\n                case Shutdown(runner_id=runner_id):\n                    runner = self.runners.pop(runner_id)\n                    try:\n                        with fail_after(3):\n                            await runner.start_task(task)\n                    except TimeoutError:\n                        await self.event_sender.send(\n                            TaskStatusUpdated(\n                                task_id=task.task_id, task_status=TaskStatus.TimedOut\n                            )\n                        )\n                    finally:\n                        runner.shutdown()\n                case CancelTask(\n                    cancelled_task_id=cancelled_task_id, runner_id=runner_id\n                ):\n                    await self.runners[runner_id].cancel_task(cancelled_task_id)\n                    await self.event_sender.send(\n                        TaskStatusUpdated(\n                            task_id=task.task_id, task_status=TaskStatus.Complete\n                        )\n                    )\n                case ImageEdits() if task.task_params.total_input_chunks > 0:\n                    # Assemble image from chunks and inject into task\n                    cmd_id = task.command_id\n                    chunks = self.input_chunk_buffer.get(cmd_id, {})\n                    assembled = \"\".join(chunks[i] for i in range(len(chunks)))\n                    logger.info(\n                        f\"Assembled input image from {len(chunks)} chunks, \"\n                        f\"total size: {len(assembled)} bytes\"\n                    )\n                    # Create modified task with assembled image data\n                    modified_task = ImageEdits(\n                        task_id=task.task_id,\n                        command_id=task.command_id,\n                        instance_id=task.instance_id,\n                        task_status=task.task_status,\n                        task_params=ImageEditsTaskParams(\n                            image_data=assembled,\n                            total_input_chunks=task.task_params.total_input_chunks,\n                            prompt=task.task_params.prompt,\n                            model=task.task_params.model,\n                            n=task.task_params.n,\n                            quality=task.task_params.quality,\n                            output_format=task.task_params.output_format,\n                            response_format=task.task_params.response_format,\n                            size=task.task_params.size,\n                            image_strength=task.task_params.image_strength,\n                            bench=task.task_params.bench,\n                            stream=task.task_params.stream,\n                            partial_images=task.task_params.partial_images,\n                            
advanced_params=task.task_params.advanced_params,\n                        ),\n                    )\n                    # Cleanup buffers\n                    if cmd_id in self.input_chunk_buffer:\n                        del self.input_chunk_buffer[cmd_id]\n                    if cmd_id in self.input_chunk_counts:\n                        del self.input_chunk_counts[cmd_id]\n                    await self._start_runner_task(modified_task)\n                case task:\n                    await self._start_runner_task(task)\n\n    def shutdown(self):\n        self._tg.cancel_tasks()\n\n    async def _start_runner_task(self, task: Task):\n        if (instance := self.state.instances.get(task.instance_id)) is not None:\n            await self.runners[\n                instance.shard_assignments.node_to_runner[self.node_id]\n            ].start_task(task)\n\n    def _create_supervisor(self, task: CreateRunner) -> RunnerSupervisor:\n        \"\"\"Creates and stores a new AssignedRunner with initial downloading status.\"\"\"\n        runner = RunnerSupervisor.create(\n            bound_instance=task.bound_instance,\n            event_sender=self.event_sender.clone(),\n        )\n        self.runners[task.bound_instance.bound_runner_id] = runner\n        self._tg.start_soon(runner.run)\n        return runner\n\n    async def _poll_connection_updates(self):\n        while True:\n            edges = set(\n                conn.edge for conn in self.state.topology.out_edges(self.node_id)\n            )\n            conns: defaultdict[NodeId, set[str]] = defaultdict(set)\n            async for ip, nid in check_reachable(\n                self.state.topology,\n                self.node_id,\n                self.state.node_network,\n            ):\n                if ip in conns[nid]:\n                    continue\n                conns[nid].add(ip)\n                edge = SocketConnection(\n                    # nonsense multiaddr\n                    sink_multiaddr=Multiaddr(address=f\"/ip4/{ip}/tcp/52415\")\n                    if \".\" in ip\n                    # nonsense multiaddr\n                    else Multiaddr(address=f\"/ip6/{ip}/tcp/52415\"),\n                )\n                if edge not in edges:\n                    logger.debug(f\"ping discovered {edge=}\")\n                    await self.event_sender.send(\n                        TopologyEdgeCreated(\n                            conn=Connection(source=self.node_id, sink=nid, edge=edge)\n                        )\n                    )\n\n            for conn in self.state.topology.out_edges(self.node_id):\n                if not isinstance(conn.edge, SocketConnection):\n                    continue\n                # ignore mDNS discovered connections\n                if conn.edge.sink_multiaddr.port != 52415:\n                    continue\n                if (\n                    conn.sink not in conns\n                    or conn.edge.sink_multiaddr.ip_address not in conns[conn.sink]\n                ):\n                    logger.debug(f\"ping failed to discover {conn=}\")\n                    await self.event_sender.send(TopologyEdgeDeleted(conn=conn))\n\n            await anyio.sleep(10)\n"
  },
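  {
    "path": "src/exo/worker/main_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the original tree.\n\nMirrors how Worker buffers InputChunkReceived events and how plan_step\nreassembles image data for ImageEdits: chunks arrive keyed by chunk_index\nand are joined in index order once all total_chunks are present. The helper\nnames below are hypothetical stand-ins for the inline logic in main.py.\n\"\"\"\n\n\ndef buffer_chunk(buffer: dict[int, str], chunk_index: int, data: str) -> None:\n    buffer[chunk_index] = data\n\n\ndef try_assemble(buffer: dict[int, str], total_chunks: int) -> str | None:\n    # plan() only forwards an ImageEdits task once every chunk has arrived\n    if len(buffer) < total_chunks:\n        return None\n    return \"\".join(buffer[index] for index in range(total_chunks))\n\n\nif __name__ == \"__main__\":\n    buffer: dict[int, str] = {}\n    buffer_chunk(buffer, 1, \"bb\")  # out-of-order arrival is fine\n    assert try_assemble(buffer, total_chunks=2) is None\n    buffer_chunk(buffer, 0, \"aa\")\n    assert try_assemble(buffer, total_chunks=2) == \"aabb\"\n"
  },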
  {
    "path": "src/exo/worker/plan.py",
    "content": "# pyright: reportUnusedImport = false\n\nfrom collections.abc import Mapping, Sequence\n\nfrom exo.shared.types.common import CommandId, NodeId\nfrom exo.shared.types.tasks import (\n    CancelTask,\n    ConnectToGroup,\n    CreateRunner,\n    DownloadModel,\n    ImageEdits,\n    ImageGeneration,\n    LoadModel,\n    Shutdown,\n    StartWarmup,\n    Task,\n    TaskId,\n    TaskStatus,\n    TextGeneration,\n)\nfrom exo.shared.types.worker.downloads import (\n    DownloadCompleted,\n    DownloadFailed,\n    DownloadOngoing,\n    DownloadProgress,\n)\nfrom exo.shared.types.worker.instances import BoundInstance, Instance, InstanceId\nfrom exo.shared.types.worker.runners import (\n    RunnerConnected,\n    RunnerConnecting,\n    RunnerFailed,\n    RunnerId,\n    RunnerIdle,\n    RunnerLoaded,\n    RunnerLoading,\n    RunnerReady,\n    RunnerRunning,\n    RunnerStatus,\n    RunnerWarmingUp,\n)\nfrom exo.worker.runner.runner_supervisor import RunnerSupervisor\n\n\ndef plan(\n    node_id: NodeId,\n    # Runners is expected to be FRESH and so should not come from state\n    runners: Mapping[RunnerId, RunnerSupervisor],\n    global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],\n    instances: Mapping[InstanceId, Instance],\n    all_runners: Mapping[RunnerId, RunnerStatus],  # all global\n    tasks: Mapping[TaskId, Task],\n    input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,\n    input_chunk_counts: Mapping[CommandId, int] | None = None,\n) -> Task | None:\n    # Python short circuiting OR logic should evaluate these sequentially.\n    return (\n        _cancel_tasks(runners, tasks)\n        or _kill_runner(runners, all_runners, instances)\n        or _create_runner(node_id, runners, instances)\n        or _model_needs_download(node_id, runners, global_download_status)\n        or _init_distributed_backend(runners, all_runners)\n        or _load_model(runners, all_runners, global_download_status)\n        or _ready_to_warmup(runners, all_runners)\n        or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer or {})\n    )\n\n\ndef _kill_runner(\n    runners: Mapping[RunnerId, RunnerSupervisor],\n    all_runners: Mapping[RunnerId, RunnerStatus],\n    instances: Mapping[InstanceId, Instance],\n) -> Shutdown | None:\n    for runner in runners.values():\n        runner_id = runner.bound_instance.bound_runner_id\n        if (instance_id := runner.bound_instance.instance.instance_id) not in instances:\n            return Shutdown(instance_id=instance_id, runner_id=runner_id)\n\n        for (\n            global_runner_id\n        ) in runner.bound_instance.instance.shard_assignments.node_to_runner.values():\n            if runner_id == global_runner_id:\n                continue\n\n            if isinstance(all_runners.get(global_runner_id, None), RunnerFailed):\n                return Shutdown(\n                    instance_id=instance_id,\n                    runner_id=runner_id,\n                )\n\n\ndef _create_runner(\n    node_id: NodeId,\n    runners: Mapping[RunnerId, RunnerSupervisor],\n    instances: Mapping[InstanceId, Instance],\n) -> CreateRunner | None:\n    for instance in instances.values():\n        runner_id = instance.shard_assignments.node_to_runner.get(node_id, None)\n        if runner_id is None:\n            continue\n\n        if runner_id in runners:\n            continue\n\n        shard = instance.shard(runner_id)\n        assert shard is not None\n\n        return CreateRunner(\n            
instance_id=instance.instance_id,\n            bound_instance=BoundInstance(\n                instance=instance, bound_runner_id=runner_id, bound_node_id=node_id\n            ),\n        )\n\n\ndef _model_needs_download(\n    node_id: NodeId,\n    runners: Mapping[RunnerId, RunnerSupervisor],\n    global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],\n) -> DownloadModel | None:\n    local_downloads = global_download_status.get(node_id, [])\n    download_status = {\n        dp.shard_metadata.model_card.model_id: dp for dp in local_downloads\n    }\n\n    for runner in runners.values():\n        model_id = runner.bound_instance.bound_shard.model_card.model_id\n        if isinstance(runner.status, RunnerIdle) and (\n            model_id not in download_status\n            or not isinstance(\n                download_status[model_id],\n                (DownloadOngoing, DownloadCompleted, DownloadFailed),\n            )\n        ):\n            # We don't invalidate download_status randomly in case a file gets deleted on disk\n            return DownloadModel(\n                instance_id=runner.bound_instance.instance.instance_id,\n                shard_metadata=runner.bound_instance.bound_shard,\n            )\n\n\ndef _init_distributed_backend(\n    runners: Mapping[RunnerId, RunnerSupervisor],\n    all_runners: Mapping[RunnerId, RunnerStatus],\n):\n    for runner in runners.values():\n        instance = runner.bound_instance.instance\n        shard_assignments = instance.shard_assignments\n\n        is_single_node_instance = len(shard_assignments.runner_to_shard) == 1\n        if is_single_node_instance:\n            continue\n\n        runner_is_idle = isinstance(runner.status, RunnerIdle)\n        all_runners_connecting = all(\n            isinstance(\n                all_runners.get(global_runner_id),\n                (RunnerConnecting, RunnerIdle),\n            )\n            for global_runner_id in shard_assignments.runner_to_shard\n        )\n\n        if not (runner_is_idle and all_runners_connecting):\n            continue\n\n        runner_id = runner.bound_instance.bound_runner_id\n\n        shard = runner.bound_instance.bound_shard\n        device_rank = shard.device_rank\n        world_size = shard.world_size\n\n        assert device_rank < world_size\n        assert device_rank >= 0\n\n        accepting_ranks = device_rank < world_size - 1\n\n        # Rank = n-1\n        connecting_rank_ready = device_rank == world_size - 1 and all(\n            isinstance(all_runners.get(global_runner_id, None), RunnerConnecting)\n            for global_runner_id in shard_assignments.runner_to_shard\n            if global_runner_id != runner_id\n        )\n\n        if not (accepting_ranks or connecting_rank_ready):\n            continue\n\n        return ConnectToGroup(instance_id=instance.instance_id)\n\n    return None\n\n\ndef _load_model(\n    runners: Mapping[RunnerId, RunnerSupervisor],\n    all_runners: Mapping[RunnerId, RunnerStatus],\n    global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],\n) -> LoadModel | None:\n    for runner in runners.values():\n        instance = runner.bound_instance.instance\n        shard_assignments = instance.shard_assignments\n\n        all_local_downloads_complete = all(\n            nid in global_download_status\n            and any(\n                isinstance(dp, DownloadCompleted)\n                and dp.shard_metadata.model_card.model_id == shard_assignments.model_id\n                for dp in 
global_download_status[nid]\n            )\n            for nid in shard_assignments.node_to_runner\n        )\n        if not all_local_downloads_complete:\n            continue\n\n        is_single_node_instance = len(instance.shard_assignments.runner_to_shard) == 1\n        if is_single_node_instance and isinstance(runner.status, RunnerIdle):\n            return LoadModel(instance_id=instance.instance_id)\n\n        is_runner_waiting = isinstance(runner.status, RunnerConnected)\n\n        all_ready_for_model = all(\n            isinstance(\n                all_runners.get(global_runner_id, None),\n                (RunnerConnected, RunnerLoading, RunnerLoaded),\n            )\n            for global_runner_id in shard_assignments.runner_to_shard\n        )\n\n        if is_runner_waiting and all_ready_for_model:\n            return LoadModel(instance_id=instance.instance_id)\n\n    return None\n\n\ndef _ready_to_warmup(\n    runners: Mapping[RunnerId, RunnerSupervisor],\n    all_runners: Mapping[RunnerId, RunnerStatus],\n) -> StartWarmup | None:\n    for runner in runners.values():\n        instance = runner.bound_instance.instance\n        shard_assignments = instance.shard_assignments\n        shard = runner.bound_instance.bound_shard\n        device_rank = shard.device_rank\n        runner_id = runner.bound_instance.bound_runner_id\n        world_size = shard.world_size\n\n        is_runner_loaded = isinstance(runner.status, RunnerLoaded)\n\n        assert device_rank < world_size\n        assert device_rank >= 0\n\n        # Rank != 0\n        accepting_ranks_ready = device_rank > 0 and all(\n            isinstance(\n                all_runners.get(global_runner_id, None),\n                (RunnerLoaded, RunnerWarmingUp),\n            )\n            for global_runner_id in shard_assignments.runner_to_shard\n        )\n\n        # Rank = 0\n        connecting_rank_ready = device_rank == 0 and all(\n            isinstance(all_runners.get(global_runner_id, None), RunnerWarmingUp)\n            for global_runner_id in shard_assignments.runner_to_shard\n            if global_runner_id != runner_id\n        )\n\n        if is_runner_loaded and (accepting_ranks_ready or connecting_rank_ready):\n            return StartWarmup(instance_id=instance.instance_id)\n\n    return None\n\n\ndef _pending_tasks(\n    runners: Mapping[RunnerId, RunnerSupervisor],\n    tasks: Mapping[TaskId, Task],\n    all_runners: Mapping[RunnerId, RunnerStatus],\n    input_chunk_buffer: Mapping[CommandId, dict[int, str]],\n) -> Task | None:\n    for task in tasks.values():\n        # for now, just forward chat completions\n        # TODO(ciaran): do this better!\n        if not isinstance(task, (TextGeneration, ImageGeneration, ImageEdits)):\n            continue\n        if task.task_status not in (TaskStatus.Pending, TaskStatus.Running):\n            continue\n\n        # For ImageEdits tasks, verify all input chunks have been received\n        if isinstance(task, ImageEdits) and task.task_params.total_input_chunks > 0:\n            cmd_id = task.command_id\n            expected = task.task_params.total_input_chunks\n            received = len(input_chunk_buffer.get(cmd_id, {}))\n            if received < expected:\n                continue  # Wait for all chunks to arrive\n\n        for runner in runners.values():\n            if task.instance_id != runner.bound_instance.instance.instance_id:\n                continue\n\n            # the task status _should_ be set to completed by the LAST runner\n            # it 
is currently set by the first\n            # this is definitely a hack\n            if task.task_id in runner.completed or task.task_id in runner.in_progress:\n                continue\n\n            if isinstance(runner.status, (RunnerReady, RunnerRunning)) and all(\n                isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))\n                for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard\n            ):\n                return task\n\n\ndef _cancel_tasks(\n    runners: Mapping[RunnerId, RunnerSupervisor],\n    tasks: Mapping[TaskId, Task],\n) -> Task | None:\n    for task in tasks.values():\n        if task.task_status != TaskStatus.Cancelled:\n            continue\n        for runner_id, runner in runners.items():\n            if task.instance_id != runner.bound_instance.instance.instance_id:\n                continue\n            if task.task_id in runner.cancelled:\n                continue\n            return CancelTask(\n                instance_id=task.instance_id,\n                cancelled_task_id=task.task_id,\n                runner_id=runner_id,\n            )\n"
  },
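  {
    "path": "src/exo/worker/plan_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the original tree.\n\nShows the priority pattern plan() relies on: each rule returns a task or\nNone, and chaining rules with `or` yields the first non-None result, so\nlexical order is priority order. PlannedTask and the lambdas below are\nhypothetical stand-ins; real Task objects are always truthy, which is what\nmakes the `or` chain safe.\n\"\"\"\n\nfrom collections.abc import Callable\n\nPlannedTask = str  # hypothetical stand-in for the real Task union\nRule = Callable[[], PlannedTask | None]\n\n\ndef plan_like(*rules: Rule) -> PlannedTask | None:\n    selected: PlannedTask | None = None\n    for rule in rules:\n        # `or` short-circuits: later rules never run once one has fired\n        selected = selected or rule()\n    return selected\n\n\nif __name__ == \"__main__\":\n    assert plan_like(lambda: None, lambda: \"download\", lambda: \"load\") == \"download\"\n"
  },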
  {
    "path": "src/exo/worker/runner/__init__.py",
    "content": ""
  },
  {
    "path": "src/exo/worker/runner/bootstrap.py",
    "content": "import os\nimport resource\n\nimport loguru\n\nfrom exo.shared.types.events import Event, RunnerStatusUpdated\nfrom exo.shared.types.tasks import Task, TaskId\nfrom exo.shared.types.worker.instances import BoundInstance\nfrom exo.shared.types.worker.runners import RunnerFailed\nfrom exo.utils.channels import ClosedResourceError, MpReceiver, MpSender\n\nlogger: \"loguru.Logger\" = loguru.logger\n\n\ndef entrypoint(\n    bound_instance: BoundInstance,\n    event_sender: MpSender[Event],\n    task_receiver: MpReceiver[Task],\n    cancel_receiver: MpReceiver[TaskId],\n    _logger: \"loguru.Logger\",\n) -> None:\n    global logger\n    logger = _logger\n\n    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)\n    resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))\n\n    fast_synch_override = os.environ.get(\"EXO_FAST_SYNCH\")\n    if fast_synch_override != \"off\":\n        os.environ[\"MLX_METAL_FAST_SYNCH\"] = \"1\"\n    else:\n        os.environ[\"MLX_METAL_FAST_SYNCH\"] = \"0\"\n\n    logger.info(f\"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}\")\n\n    # Import main after setting global logger - this lets us just import logger from this module\n    try:\n        if bound_instance.is_image_model:\n            from exo.worker.runner.image_models.runner import Runner as ImageRunner\n\n            runner = ImageRunner(\n                bound_instance, event_sender, task_receiver, cancel_receiver\n            )\n            runner.main()\n        else:\n            from exo.worker.runner.llm_inference.runner import Runner\n\n            runner = Runner(\n                bound_instance, event_sender, task_receiver, cancel_receiver\n            )\n            runner.main()\n\n    except ClosedResourceError:\n        logger.warning(\"Runner communication closed unexpectedly\")\n    except Exception as e:\n        logger.opt(exception=e).warning(\n            f\"Runner {bound_instance.bound_runner_id} crashed with critical exception {e}\"\n        )\n        event_sender.send(\n            RunnerStatusUpdated(\n                runner_id=bound_instance.bound_runner_id,\n                runner_status=RunnerFailed(error_message=str(e)),\n            )\n        )\n    finally:\n        try:\n            event_sender.close()\n            task_receiver.close()\n        finally:\n            event_sender.join()\n            task_receiver.join()\n            logger.info(\"bye from the runner\")\n"
  },
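  {
    "path": "src/exo/worker/runner/bootstrap_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the original tree.\n\nSpells out the file-descriptor bump in bootstrap.entrypoint: raise the soft\nRLIMIT_NOFILE to at least 2048 without ever exceeding the hard cap. The\nfunction name and the floor parameter are hypothetical stand-ins for the\ninline min(max(soft, 2048), hard) expression.\n\"\"\"\n\n\ndef bumped_soft_limit(soft: int, hard: int, floor: int = 2048) -> int:\n    # grow the soft limit to the floor, then clamp at the hard limit\n    return min(max(soft, floor), hard)\n\n\nif __name__ == \"__main__\":\n    assert bumped_soft_limit(256, 10240) == 2048  # raised to the floor\n    assert bumped_soft_limit(4096, 10240) == 4096  # already high enough\n    assert bumped_soft_limit(256, 1024) == 1024  # capped by the hard limit\n"
  },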
  {
    "path": "src/exo/worker/runner/image_models/__init__.py",
    "content": ""
  },
  {
    "path": "src/exo/worker/runner/image_models/runner.py",
    "content": "import base64\nimport time\nfrom typing import TYPE_CHECKING, Literal\n\nimport mlx.core as mx\n\nfrom exo.api.types import ImageGenerationStats\nfrom exo.shared.constants import EXO_MAX_CHUNK_SIZE, EXO_TRACING_ENABLED\nfrom exo.shared.models.model_cards import ModelTask\nfrom exo.shared.tracing import clear_trace_buffer, get_trace_buffer\nfrom exo.shared.types.chunks import ErrorChunk, ImageChunk\nfrom exo.shared.types.common import CommandId, ModelId\nfrom exo.shared.types.events import (\n    ChunkGenerated,\n    Event,\n    RunnerStatusUpdated,\n    TaskAcknowledged,\n    TaskStatusUpdated,\n    TraceEventData,\n    TracesCollected,\n)\nfrom exo.shared.types.tasks import (\n    CANCEL_ALL_TASKS,\n    ConnectToGroup,\n    ImageEdits,\n    ImageGeneration,\n    LoadModel,\n    Shutdown,\n    StartWarmup,\n    Task,\n    TaskId,\n    TaskStatus,\n)\nfrom exo.shared.types.worker.instances import BoundInstance\nfrom exo.shared.types.worker.runner_response import (\n    ImageGenerationResponse,\n    PartialImageResponse,\n)\nfrom exo.shared.types.worker.runners import (\n    RunnerConnected,\n    RunnerConnecting,\n    RunnerFailed,\n    RunnerIdle,\n    RunnerLoaded,\n    RunnerLoading,\n    RunnerReady,\n    RunnerRunning,\n    RunnerShutdown,\n    RunnerShuttingDown,\n    RunnerStatus,\n    RunnerWarmingUp,\n)\nfrom exo.shared.types.worker.shards import (\n    CfgShardMetadata,\n    PipelineShardMetadata,\n    ShardMetadata,\n)\nfrom exo.utils.channels import MpReceiver, MpSender\nfrom exo.worker.engines.image import (\n    DistributedImageModel,\n    generate_image,\n    initialize_image_model,\n    warmup_image_generator,\n)\nfrom exo.worker.engines.mlx.utils_mlx import (\n    initialize_mlx,\n)\nfrom exo.worker.runner.bootstrap import logger\n\n\ndef _is_primary_output_node(shard_metadata: ShardMetadata) -> bool:\n    \"\"\"Check if this node is the primary output node for image generation.\n\n    For CFG models: the last pipeline stage in CFG group 0 (positive prompt).\n    For non-CFG models: the last pipeline stage.\n    \"\"\"\n    if isinstance(shard_metadata, CfgShardMetadata):\n        is_pipeline_last = (\n            shard_metadata.pipeline_rank == shard_metadata.pipeline_world_size - 1\n        )\n        return is_pipeline_last and shard_metadata.cfg_rank == 0\n    elif isinstance(shard_metadata, PipelineShardMetadata):\n        return shard_metadata.device_rank == shard_metadata.world_size - 1\n    return False\n\n\ndef _process_image_response(\n    response: ImageGenerationResponse | PartialImageResponse,\n    command_id: CommandId,\n    shard_metadata: ShardMetadata,\n    event_sender: MpSender[Event],\n    image_index: int,\n) -> None:\n    \"\"\"Process a single image response and send chunks.\"\"\"\n    encoded_data = base64.b64encode(response.image_data).decode(\"utf-8\")\n    is_partial = isinstance(response, PartialImageResponse)\n    # Extract stats from final ImageGenerationResponse if available\n    stats = response.stats if isinstance(response, ImageGenerationResponse) else None\n    _send_image_chunk(\n        encoded_data=encoded_data,\n        command_id=command_id,\n        model_id=shard_metadata.model_card.model_id,\n        event_sender=event_sender,\n        image_index=response.image_index,\n        is_partial=is_partial,\n        partial_index=response.partial_index if is_partial else None,\n        total_partials=response.total_partials if is_partial else None,\n        stats=stats,\n        image_format=response.format,\n    
)\n\n\ndef _send_traces_if_enabled(\n    event_sender: MpSender[Event],\n    task_id: TaskId,\n    rank: int,\n) -> None:\n    if not EXO_TRACING_ENABLED:\n        return\n\n    traces = get_trace_buffer()\n    if traces:\n        trace_data = [\n            TraceEventData(\n                name=t.name,\n                start_us=t.start_us,\n                duration_us=t.duration_us,\n                rank=t.rank,\n                category=t.category,\n            )\n            for t in traces\n        ]\n        event_sender.send(\n            TracesCollected(\n                task_id=task_id,\n                rank=rank,\n                traces=trace_data,\n            )\n        )\n    clear_trace_buffer()\n\n\ndef _send_image_chunk(\n    encoded_data: str,\n    command_id: CommandId,\n    model_id: ModelId,\n    event_sender: MpSender[Event],\n    image_index: int,\n    is_partial: bool,\n    partial_index: int | None = None,\n    total_partials: int | None = None,\n    stats: ImageGenerationStats | None = None,\n    image_format: Literal[\"png\", \"jpeg\", \"webp\"] | None = None,\n) -> None:\n    \"\"\"Send base64-encoded image data as chunks via events.\"\"\"\n    data_chunks = [\n        encoded_data[i : i + EXO_MAX_CHUNK_SIZE]\n        for i in range(0, len(encoded_data), EXO_MAX_CHUNK_SIZE)\n    ]\n    total_chunks = len(data_chunks)\n    for chunk_index, chunk_data in enumerate(data_chunks):\n        # Only include stats on the last chunk of the final image\n        chunk_stats = (\n            stats if chunk_index == total_chunks - 1 and not is_partial else None\n        )\n        event_sender.send(\n            ChunkGenerated(\n                command_id=command_id,\n                chunk=ImageChunk(\n                    model=model_id,\n                    data=chunk_data,\n                    chunk_index=chunk_index,\n                    total_chunks=total_chunks,\n                    image_index=image_index,\n                    is_partial=is_partial,\n                    partial_index=partial_index,\n                    total_partials=total_partials,\n                    stats=chunk_stats,\n                    format=image_format,\n                ),\n            )\n        )\n\n\nclass Runner:\n    def __init__(\n        self,\n        bound_instance: BoundInstance,\n        event_sender: MpSender[Event],\n        task_receiver: MpReceiver[Task],\n        cancel_receiver: MpReceiver[TaskId],\n    ):\n        self.event_sender = event_sender\n        self.task_receiver = task_receiver\n        self.cancel_receiver = cancel_receiver\n        self.bound_instance = bound_instance\n\n        self.instance, self.runner_id, self.shard_metadata = (\n            bound_instance.instance,\n            bound_instance.bound_runner_id,\n            bound_instance.bound_shard,\n        )\n        self.device_rank = self.shard_metadata.device_rank\n\n        logger.info(\"hello from the runner\")\n        if getattr(self.shard_metadata, \"immediate_exception\", False):\n            raise Exception(\"Fake exception - runner failed to spin up.\")\n        if timeout := getattr(self.shard_metadata, \"should_timeout\", 0):\n            time.sleep(timeout)\n\n        self.setup_start_time = time.time()\n        self.cancelled_tasks = set[TaskId]()\n\n        self.image_model: DistributedImageModel | None = None\n        self.group = None\n\n        self.current_status: RunnerStatus = RunnerIdle()\n        logger.info(\"runner created\")\n        self.update_status(RunnerIdle())\n        
self.seen = set[TaskId]()\n\n    def update_status(self, status: RunnerStatus):\n        self.current_status = status\n        self.event_sender.send(\n            RunnerStatusUpdated(\n                runner_id=self.runner_id, runner_status=self.current_status\n            )\n        )\n\n    def send_task_status(self, task: Task, status: TaskStatus):\n        self.event_sender.send(\n            TaskStatusUpdated(task_id=task.task_id, task_status=status)\n        )\n\n    def acknowledge_task(self, task: Task):\n        self.event_sender.send(TaskAcknowledged(task_id=task.task_id))\n\n    def main(self):\n        with self.task_receiver as tasks:\n            for task in tasks:\n                if task.task_id in self.seen:\n                    logger.warning(\"repeat task - potential error\")\n                self.seen.add(task.task_id)\n                self.cancelled_tasks.discard(CANCEL_ALL_TASKS)\n                self.send_task_status(task, TaskStatus.Running)\n                self.handle_task(task)\n                was_cancelled = (task.task_id in self.cancelled_tasks) or (\n                    CANCEL_ALL_TASKS in self.cancelled_tasks\n                )\n                if not was_cancelled:\n                    self.send_task_status(task, TaskStatus.Complete)\n                self.update_status(self.current_status)\n\n                if isinstance(self.current_status, RunnerShutdown):\n                    break\n\n    def handle_task(self, task: Task):\n        match task:\n            case ConnectToGroup() if isinstance(\n                self.current_status, (RunnerIdle, RunnerFailed)\n            ):\n                logger.info(\"runner connecting\")\n                self.update_status(RunnerConnecting())\n                self.acknowledge_task(task)\n                self.group = initialize_mlx(self.bound_instance)\n\n                logger.info(\"runner connected\")\n                self.current_status = RunnerConnected()\n\n            # we load the model if it's connected with a group, or idle without a group. 
we should never tell a model to connect if it doesn't need to\n            case LoadModel() if (\n                isinstance(self.current_status, RunnerConnected)\n                and self.group is not None\n            ) or (isinstance(self.current_status, RunnerIdle) and self.group is None):\n                logger.info(\"runner loading\")\n                self.update_status(RunnerLoading())\n                self.acknowledge_task(task)\n\n                assert (\n                    ModelTask.TextToImage in self.shard_metadata.model_card.tasks\n                    or ModelTask.ImageToImage in self.shard_metadata.model_card.tasks\n                ), f\"Incorrect model task(s): {self.shard_metadata.model_card.tasks}\"\n\n                self.image_model = initialize_image_model(self.bound_instance)\n                self.current_status = RunnerLoaded()\n                logger.info(\"runner loaded\")\n\n            case StartWarmup() if isinstance(self.current_status, RunnerLoaded):\n                logger.info(\"runner warming up\")\n                self.update_status(RunnerWarmingUp())\n                self.acknowledge_task(task)\n\n                logger.info(f\"warming up inference for instance: {self.instance}\")\n\n                assert self.image_model\n                image = warmup_image_generator(model=self.image_model)\n                if image is not None:\n                    logger.info(f\"warmed up by generating {image.size} image\")\n                else:\n                    logger.info(\"warmup completed (non-primary node)\")\n\n                logger.info(\n                    f\"runner initialized in {time.time() - self.setup_start_time} seconds\"\n                )\n\n                self.current_status = RunnerReady()\n                logger.info(\"runner ready\")\n\n            case ImageGeneration(task_params=task_params, command_id=command_id) if (\n                isinstance(self.current_status, RunnerReady)\n            ):\n                assert self.image_model\n                logger.info(f\"received image generation request: {str(task)[:500]}\")\n                logger.info(\"runner running\")\n                self.update_status(RunnerRunning())\n                self.acknowledge_task(task)\n\n                try:\n                    image_index = 0\n                    for response in generate_image(\n                        model=self.image_model, task=task_params\n                    ):\n                        is_primary_output = _is_primary_output_node(self.shard_metadata)\n\n                        if is_primary_output:\n                            match response:\n                                case PartialImageResponse():\n                                    logger.info(\n                                        f\"sending partial ImageChunk {response.partial_index}/{response.total_partials}\"\n                                    )\n                                    _process_image_response(\n                                        response,\n                                        command_id,\n                                        self.shard_metadata,\n                                        self.event_sender,\n                                        image_index,\n                                    )\n                                case ImageGenerationResponse():\n                                    logger.info(\"sending final ImageChunk\")\n                                    _process_image_response(\n                                        
response,\n                                        command_id,\n                                        self.shard_metadata,\n                                        self.event_sender,\n                                        image_index,\n                                    )\n                                    image_index += 1\n                # can we make this more explicit?\n                except Exception as e:\n                    if _is_primary_output_node(self.shard_metadata):\n                        self.event_sender.send(\n                            ChunkGenerated(\n                                command_id=command_id,\n                                chunk=ErrorChunk(\n                                    model=self.shard_metadata.model_card.model_id,\n                                    finish_reason=\"error\",\n                                    error_message=str(e),\n                                ),\n                            )\n                        )\n                    raise\n                finally:\n                    _send_traces_if_enabled(\n                        self.event_sender, task.task_id, self.device_rank\n                    )\n\n                self.current_status = RunnerReady()\n                logger.info(\"runner ready\")\n\n            case ImageEdits(task_params=task_params, command_id=command_id) if (\n                isinstance(self.current_status, RunnerReady)\n            ):\n                assert self.image_model\n                logger.info(f\"received image edits request: {str(task)[:500]}\")\n                logger.info(\"runner running\")\n                self.update_status(RunnerRunning())\n                self.acknowledge_task(task)\n\n                try:\n                    image_index = 0\n                    for response in generate_image(\n                        model=self.image_model, task=task_params\n                    ):\n                        if _is_primary_output_node(self.shard_metadata):\n                            match response:\n                                case PartialImageResponse():\n                                    logger.info(\n                                        f\"sending partial ImageChunk {response.partial_index}/{response.total_partials}\"\n                                    )\n                                    _process_image_response(\n                                        response,\n                                        command_id,\n                                        self.shard_metadata,\n                                        self.event_sender,\n                                        image_index,\n                                    )\n                                case ImageGenerationResponse():\n                                    logger.info(\"sending final ImageChunk\")\n                                    _process_image_response(\n                                        response,\n                                        command_id,\n                                        self.shard_metadata,\n                                        self.event_sender,\n                                        image_index,\n                                    )\n                                    image_index += 1\n                except Exception as e:\n                    if _is_primary_output_node(self.shard_metadata):\n                        self.event_sender.send(\n                            ChunkGenerated(\n                                command_id=command_id,\n           
                     chunk=ErrorChunk(\n                                    model=self.shard_metadata.model_card.model_id,\n                                    finish_reason=\"error\",\n                                    error_message=str(e),\n                                ),\n                            )\n                        )\n                    raise\n                finally:\n                    _send_traces_if_enabled(\n                        self.event_sender, task.task_id, self.device_rank\n                    )\n\n                self.current_status = RunnerReady()\n                logger.info(\"runner ready\")\n\n            case Shutdown():\n                logger.info(\"runner shutting down\")\n                if not TYPE_CHECKING:\n                    del self.image_model, self.group\n                    mx.clear_cache()\n                    import gc\n\n                    gc.collect()\n\n                self.update_status(RunnerShuttingDown())\n                self.acknowledge_task(task)\n\n                self.current_status = RunnerShutdown()\n            case _:\n                raise ValueError(\n                    f\"Received {task.__class__.__name__} outside of state machine in {self.current_status=}\"\n                )\n"
  },
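  {
    "path": "src/exo/worker/runner/image_models/runner_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the original tree.\n\nMirrors the slicing rule in _send_image_chunk from the image runner:\nbase64 payloads are cut into EXO_MAX_CHUNK_SIZE pieces and stats ride only\non the last chunk of a final (non-partial) image. CHUNK_SIZE and the plain\ndict payloads are hypothetical stand-ins for the real ImageChunk events.\n\"\"\"\n\nCHUNK_SIZE = 4  # stand-in for EXO_MAX_CHUNK_SIZE\n\n\ndef split_chunks(\n    encoded: str, stats: str | None, is_partial: bool\n) -> list[dict[str, object]]:\n    pieces = [\n        encoded[start : start + CHUNK_SIZE]\n        for start in range(0, len(encoded), CHUNK_SIZE)\n    ]\n    return [\n        {\n            \"data\": piece,\n            \"chunk_index\": index,\n            \"total_chunks\": len(pieces),\n            # only the last chunk of a final image carries stats\n            \"stats\": stats if index == len(pieces) - 1 and not is_partial else None,\n        }\n        for index, piece in enumerate(pieces)\n    ]\n\n\nif __name__ == \"__main__\":\n    chunks = split_chunks(\"abcdefgh\", stats=\"t=1.2s\", is_partial=False)\n    assert [chunk[\"data\"] for chunk in chunks] == [\"abcd\", \"efgh\"]\n    assert [chunk[\"stats\"] for chunk in chunks] == [None, \"t=1.2s\"]\n"
  },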
  {
    "path": "src/exo/worker/runner/llm_inference/__init__.py",
    "content": ""
  },
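  {
    "path": "src/exo/worker/runner/llm_inference/_generator_queue_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the original tree.\n\nStandalone restatement of GeneratorQueue from batch_generator.py: gen() is\nan endless generator that yields None when the deque is empty instead of\nblocking, which lets SequentialGenerator interleave raw token production\nwith parsing. PeekQueue is a hypothetical stand-in name.\n\"\"\"\n\nfrom collections import deque\nfrom collections.abc import Generator\n\n\nclass PeekQueue[T]:\n    def __init__(self) -> None:\n        self._queue = deque[T]()\n\n    def push(self, item: T) -> None:\n        self._queue.append(item)\n\n    def gen(self) -> Generator[T | None]:\n        while True:\n            # yield None on empty rather than blocking the caller\n            yield self._queue.popleft() if self._queue else None\n\n\nif __name__ == \"__main__\":\n    queue = PeekQueue[int]()\n    stream = queue.gen()\n    assert next(stream) is None\n    queue.push(7)\n    assert next(stream) == 7\n"
  },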
  {
    "path": "src/exo/worker/runner/llm_inference/batch_generator.py",
    "content": "import itertools\nimport time\nfrom abc import ABC, abstractmethod\nfrom collections import deque\nfrom collections.abc import Generator, Iterable\nfrom dataclasses import dataclass, field\n\nimport mlx.core as mx\nfrom mlx_lm.tokenizer_utils import TokenizerWrapper\n\nfrom exo.shared.constants import EXO_MAX_CONCURRENT_REQUESTS\nfrom exo.shared.types.chunks import ErrorChunk, PrefillProgressChunk\nfrom exo.shared.types.common import ModelId\nfrom exo.shared.types.events import ChunkGenerated, Event\nfrom exo.shared.types.mlx import Model\nfrom exo.shared.types.tasks import CANCEL_ALL_TASKS, TaskId, TextGeneration\nfrom exo.shared.types.text_generation import TextGenerationTaskParams\nfrom exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse\nfrom exo.utils.channels import MpReceiver, MpSender\nfrom exo.worker.engines.mlx.cache import KVPrefixCache\nfrom exo.worker.engines.mlx.generator.batch_generate import ExoBatchGenerator\nfrom exo.worker.engines.mlx.generator.generate import (\n    PrefillCancelled,\n    mlx_generate,\n    warmup_inference,\n)\nfrom exo.worker.engines.mlx.utils_mlx import (\n    apply_chat_template,\n    mx_all_gather_tasks,\n    mx_any,\n)\nfrom exo.worker.runner.bootstrap import logger\n\nfrom .model_output_parsers import apply_all_parsers\nfrom .tool_parsers import ToolParser\n\n\nclass Cancelled:\n    pass\n\n\nclass Finished:\n    pass\n\n\nclass GeneratorQueue[T]:\n    def __init__(self):\n        self._q = deque[T]()\n\n    def push(self, t: T):\n        self._q.append(t)\n\n    def gen(self) -> Generator[T | None]:\n        while True:\n            if len(self._q) == 0:\n                yield None\n            else:\n                yield self._q.popleft()\n\n\nclass InferenceGenerator(ABC):\n    _cancelled_tasks: set[TaskId]\n\n    def should_cancel(self, task_id: TaskId) -> bool:\n        return (\n            task_id in self._cancelled_tasks\n            or CANCEL_ALL_TASKS in self._cancelled_tasks\n        )\n\n    @abstractmethod\n    def warmup(self) -> None: ...\n\n    @abstractmethod\n    def submit(\n        self,\n        task: TextGeneration,\n    ) -> None: ...\n\n    @abstractmethod\n    def step(\n        self,\n    ) -> Iterable[\n        tuple[TaskId, ToolCallResponse | GenerationResponse | Cancelled | Finished]\n    ]: ...\n\n    @abstractmethod\n    def close(self) -> None: ...\n\n\nEXO_RUNNER_MUST_FAIL = \"EXO RUNNER MUST FAIL\"\nEXO_RUNNER_MUST_OOM = \"EXO RUNNER MUST OOM\"\nEXO_RUNNER_MUST_TIMEOUT = \"EXO RUNNER MUST TIMEOUT\"\n\n\ndef _check_for_debug_prompts(task_params: TextGenerationTaskParams) -> None:\n    \"\"\"Check for debug prompt triggers in the input.\"\"\"\n    from exo.worker.engines.mlx.utils_mlx import mlx_force_oom\n\n    if len(task_params.input) == 0:\n        return\n    prompt = task_params.input[0].content\n    if not prompt:\n        return\n    if EXO_RUNNER_MUST_FAIL in prompt:\n        raise Exception(\"Artificial runner exception - for testing purposes only.\")\n    if EXO_RUNNER_MUST_OOM in prompt:\n        mlx_force_oom()\n    if EXO_RUNNER_MUST_TIMEOUT in prompt:\n        time.sleep(100)\n\n\n@dataclass(eq=False)\nclass SequentialGenerator(InferenceGenerator):\n    model: Model\n    tokenizer: TokenizerWrapper\n    group: mx.distributed.Group | None\n    kv_prefix_cache: KVPrefixCache | None\n    tool_parser: ToolParser | None\n    model_id: ModelId\n    device_rank: int\n    cancel_receiver: MpReceiver[TaskId]\n    event_sender: MpSender[Event]\n    
check_for_cancel_every: int = 50\n\n    _cancelled_tasks: set[TaskId] = field(default_factory=set, init=False)\n    _maybe_queue: list[TextGeneration] = field(default_factory=list, init=False)\n    _maybe_cancel: list[TextGeneration] = field(default_factory=list, init=False)\n    _all_tasks: dict[TaskId, TextGeneration] = field(default_factory=dict, init=False)\n    _queue: deque[TextGeneration] = field(default_factory=deque, init=False)\n    _active: (\n        tuple[\n            TextGeneration,\n            # raw mlx generator that produces tokens\n            Generator[GenerationResponse],\n            # queue bridging the raw generator and the parsing generator\n            GeneratorQueue[GenerationResponse],\n            # generator that yields parsed outputs\n            Generator[GenerationResponse | ToolCallResponse | None],\n        ]\n        | None\n    ) = field(default=None, init=False)\n\n    def warmup(self):\n        self.check_for_cancel_every = warmup_inference(\n            model=self.model,\n            tokenizer=self.tokenizer,\n            group=self.group,\n            model_id=self.model_id,\n        )\n\n    def submit(\n        self,\n        task: TextGeneration,\n    ) -> None:\n        self._cancelled_tasks.discard(CANCEL_ALL_TASKS)\n        self._all_tasks[task.task_id] = task\n        self._maybe_queue.append(task)\n\n    def agree_on_tasks(self) -> None:\n        \"\"\"Agree between all ranks on task ordering (some ranks may have received tasks in a different order, or not at all).\"\"\"\n        agreed, different = mx_all_gather_tasks(self._maybe_queue, self.group)\n        self._queue.extend(task for task in self._maybe_queue if task in agreed)\n        self._maybe_queue = [task for task in self._maybe_queue if task in different]\n\n    def agree_on_cancellations(self) -> None:\n        \"\"\"Agree between all ranks about which tasks to cancel.\"\"\"\n        has_cancel_all = False\n        for task_id in self.cancel_receiver.collect():\n            if task_id == CANCEL_ALL_TASKS:\n                has_cancel_all = True\n                continue\n            if task_id in self._all_tasks:\n                self._maybe_cancel.append(self._all_tasks[task_id])\n\n        if mx_any(has_cancel_all, self.group):\n            self._cancelled_tasks.add(CANCEL_ALL_TASKS)\n\n        agreed, different = mx_all_gather_tasks(self._maybe_cancel, self.group)\n        self._cancelled_tasks.update(task.task_id for task in agreed)\n        self._maybe_cancel = list(different)\n\n    def step(\n        self,\n    ) -> Iterable[\n        tuple[TaskId, GenerationResponse | ToolCallResponse | Cancelled | Finished]\n    ]:\n        if self._active is None:\n            self.agree_on_tasks()\n\n            if self._queue:\n                self._start_next()\n            else:\n                return map(lambda task: (task, Cancelled()), self._cancelled_tasks)\n\n        assert self._active is not None\n\n        task, mlx_gen, queue, output_generator = self._active\n        response = None\n        try:\n            queue.push(next(mlx_gen))\n            response = next(output_generator)\n        except (StopIteration, PrefillCancelled):\n            response = Finished()\n            self._active = None\n            if self._queue:\n                self._start_next()\n        except Exception as e:\n            self._send_error(task, e)\n            self._active = None\n            raise\n        return itertools.chain(\n            [] if response is None else [(task.task_id, 
response)],\n            map(lambda task: (task, Cancelled()), self._cancelled_tasks),\n        )\n\n    def _start_next(self) -> None:\n        task = self._queue.popleft()\n        try:\n            mlx_gen = self._build_generator(task)\n        except Exception as e:\n            self._send_error(task, e)\n            raise\n        queue = GeneratorQueue[GenerationResponse]()\n\n        if task.task_params.bench:\n            output_generator = queue.gen()\n        else:\n            output_generator = apply_all_parsers(\n                queue.gen(),\n                apply_chat_template(self.tokenizer, task.task_params),\n                self.tool_parser,\n                self.tokenizer,\n                type(self.model),\n                self.model_id,\n                task.task_params.tools,\n            )\n        self._active = (task, mlx_gen, queue, output_generator)\n\n    def _send_error(self, task: TextGeneration, e: Exception) -> None:\n        if self.device_rank == 0:\n            self.event_sender.send(\n                ChunkGenerated(\n                    command_id=task.command_id,\n                    chunk=ErrorChunk(\n                        model=self.model_id,\n                        finish_reason=\"error\",\n                        error_message=str(e),\n                    ),\n                )\n            )\n\n    def _build_generator(self, task: TextGeneration) -> Generator[GenerationResponse]:\n        _check_for_debug_prompts(task.task_params)\n        prompt = apply_chat_template(self.tokenizer, task.task_params)\n\n        def on_prefill_progress(processed: int, total: int) -> None:\n            if self.device_rank == 0:\n                self.event_sender.send(\n                    ChunkGenerated(\n                        command_id=task.command_id,\n                        chunk=PrefillProgressChunk(\n                            model=self.model_id,\n                            processed_tokens=processed,\n                            total_tokens=total,\n                        ),\n                    )\n                )\n\n        def distributed_prompt_progress_callback() -> None:\n            self.agree_on_cancellations()\n            if self.should_cancel(task.task_id):\n                raise PrefillCancelled()\n\n            self.agree_on_tasks()\n\n        tokens_since_cancel_check = self.check_for_cancel_every\n\n        def on_generation_token() -> None:\n            nonlocal tokens_since_cancel_check\n            tokens_since_cancel_check += 1\n            if tokens_since_cancel_check >= self.check_for_cancel_every:\n                tokens_since_cancel_check = 0\n                self.agree_on_cancellations()\n                if self.should_cancel(task.task_id):\n                    raise PrefillCancelled()\n\n                self.agree_on_tasks()\n\n        return mlx_generate(\n            model=self.model,\n            tokenizer=self.tokenizer,\n            task=task.task_params,\n            prompt=prompt,\n            kv_prefix_cache=self.kv_prefix_cache,\n            on_prefill_progress=on_prefill_progress,\n            distributed_prompt_progress_callback=distributed_prompt_progress_callback,\n            on_generation_token=on_generation_token,\n            group=self.group,\n        )\n\n    def close(self) -> None:\n        del self.model, self.tokenizer, self.group\n\n\n@dataclass(eq=False)\nclass BatchGenerator(InferenceGenerator):\n    model: Model\n    tokenizer: TokenizerWrapper\n    group: mx.distributed.Group | None\n    
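# Optional KV prefix cache; __post_init__ hands it to ExoBatchGenerator.\n    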
kv_prefix_cache: KVPrefixCache | None\n    tool_parser: ToolParser | None\n    model_id: ModelId\n    device_rank: int\n    cancel_receiver: MpReceiver[TaskId]\n    event_sender: MpSender[Event]\n    check_for_cancel_every: int = 50\n\n    _cancelled_tasks: set[TaskId] = field(default_factory=set, init=False)\n    _maybe_queue: list[TextGeneration] = field(default_factory=list, init=False)\n    _maybe_cancel: list[TextGeneration] = field(default_factory=list, init=False)\n    _all_tasks: dict[TaskId, TextGeneration] = field(default_factory=dict, init=False)\n    _queue: deque[TextGeneration] = field(default_factory=deque, init=False)\n    _mlx_gen: ExoBatchGenerator = field(init=False)\n    _active_tasks: dict[\n        int,\n        tuple[\n            TextGeneration,\n            GeneratorQueue[GenerationResponse],\n            Generator[GenerationResponse | ToolCallResponse | None],\n        ],\n    ] = field(default_factory=dict, init=False)\n\n    def __post_init__(self) -> None:\n        self._mlx_gen = ExoBatchGenerator(\n            model=self.model,\n            tokenizer=self.tokenizer,\n            group=self.group,\n            kv_prefix_cache=self.kv_prefix_cache,\n        )\n\n    def warmup(self):\n        self.check_for_cancel_every = warmup_inference(\n            model=self.model,\n            tokenizer=self.tokenizer,\n            group=self.group,\n            model_id=self.model_id,\n        )\n\n    def submit(\n        self,\n        task: TextGeneration,\n    ) -> None:\n        self._cancelled_tasks.discard(CANCEL_ALL_TASKS)\n        self._all_tasks[task.task_id] = task\n        self._maybe_queue.append(task)\n\n    def agree_on_tasks(self) -> None:\n        \"\"\"Agree between all ranks on task ordering (some ranks may have received tasks in a different order, or not at all).\"\"\"\n        agreed, different = mx_all_gather_tasks(self._maybe_queue, self.group)\n        self._queue.extend(task for task in self._maybe_queue if task in agreed)\n        self._maybe_queue = [task for task in self._maybe_queue if task in different]\n\n    def agree_on_cancellations(self) -> None:\n        \"\"\"Agree between all ranks about which tasks to cancel.\"\"\"\n        has_cancel_all = False\n        for task_id in self.cancel_receiver.collect():\n            if task_id == CANCEL_ALL_TASKS:\n                has_cancel_all = True\n                continue\n            if task_id in self._all_tasks:\n                self._maybe_cancel.append(self._all_tasks[task_id])\n\n        if mx_any(has_cancel_all, self.group):\n            self._cancelled_tasks.add(CANCEL_ALL_TASKS)\n\n        agreed, different = mx_all_gather_tasks(self._maybe_cancel, self.group)\n        self._cancelled_tasks.update(task.task_id for task in agreed)\n        self._maybe_cancel = list(different)\n\n    def step(\n        self,\n    ) -> Iterable[\n        tuple[TaskId, GenerationResponse | ToolCallResponse | Cancelled | Finished]\n    ]:\n        if not self._queue:\n            self.agree_on_tasks()\n\n        # Submit any queued tasks to the engine\n        while self._queue and len(self._active_tasks) < EXO_MAX_CONCURRENT_REQUESTS:\n            task = self._queue.popleft()\n            try:\n                uid = self._start_task(task)\n            except PrefillCancelled:\n                continue\n            except Exception as e:\n                self._send_error(task, e)\n                raise\n\n            queue = GeneratorQueue[GenerationResponse]()\n            if task.task_params.bench:\n               
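 # Bench tasks bypass the parser pipeline and stream raw generator output.\n               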
 output_generator = queue.gen()\n            else:\n                output_generator = apply_all_parsers(\n                    queue.gen(),\n                    apply_chat_template(self.tokenizer, task.task_params),\n                    self.tool_parser,\n                    self.tokenizer,\n                    type(self.model),\n                    self.model_id,\n                    task.task_params.tools,\n                )\n            self._active_tasks[uid] = (task, queue, output_generator)\n\n        if not self._mlx_gen.has_work:\n            return self._apply_cancellations()\n\n        results = self._mlx_gen.step()\n\n        output: list[\n            tuple[TaskId, GenerationResponse | ToolCallResponse | Cancelled | Finished]\n        ] = []\n        for uid, response in results:\n            if uid not in self._active_tasks:\n                # should we error here?\n                logger.warning(f\"{uid=} not found in active tasks\")\n                continue\n\n            task, queue, output_generator = self._active_tasks[uid]\n            queue.push(response)\n            parsed = next(output_generator)\n\n            if parsed is not None:\n                output.append((task.task_id, parsed))\n\n            if response.finish_reason is not None:\n                output.append((task.task_id, Finished()))\n                del self._active_tasks[uid]\n\n        return itertools.chain(output, self._apply_cancellations())\n\n    def _apply_cancellations(\n        self,\n    ) -> list[tuple[TaskId, Cancelled]]:\n        if not self._cancelled_tasks:\n            return []\n\n        cancel_all = CANCEL_ALL_TASKS in self._cancelled_tasks\n\n        uids_to_cancel: list[int] = []\n        results: list[tuple[TaskId, Cancelled]] = []\n\n        for uid, (task, _, _) in list(self._active_tasks.items()):\n            if task.task_id in self._cancelled_tasks or cancel_all:\n                uids_to_cancel.append(uid)\n                results.append((task.task_id, Cancelled()))\n                del self._active_tasks[uid]\n\n        if uids_to_cancel:\n            self._mlx_gen.cancel(uids_to_cancel)\n\n        already_cancelled = {tid for tid, _ in results}\n        for tid in self._cancelled_tasks:\n            if tid != CANCEL_ALL_TASKS and tid not in already_cancelled:\n                results.append((tid, Cancelled()))\n\n        self._cancelled_tasks.clear()\n        return results\n\n    def _send_error(self, task: TextGeneration, e: Exception) -> None:\n        if self.device_rank == 0:\n            self.event_sender.send(\n                ChunkGenerated(\n                    command_id=task.command_id,\n                    chunk=ErrorChunk(\n                        model=self.model_id,\n                        finish_reason=\"error\",\n                        error_message=str(e),\n                    ),\n                )\n            )\n\n    def _start_task(self, task: TextGeneration) -> int:\n        _check_for_debug_prompts(task.task_params)\n        prompt = apply_chat_template(self.tokenizer, task.task_params)\n\n        def on_prefill_progress(processed: int, total: int) -> None:\n            if self.device_rank == 0:\n                self.event_sender.send(\n                    ChunkGenerated(\n                        command_id=task.command_id,\n                        chunk=PrefillProgressChunk(\n                            model=self.model_id,\n                            processed_tokens=processed,\n                            total_tokens=total,\n                
        ),\n                    )\n                )\n\n        def distributed_prompt_progress_callback() -> None:\n            self.agree_on_cancellations()\n            if self.should_cancel(task.task_id):\n                raise PrefillCancelled()\n\n            self.agree_on_tasks()\n\n        tokens_since_cancel_check = self.check_for_cancel_every\n\n        def on_generation_token() -> None:\n            nonlocal tokens_since_cancel_check\n            tokens_since_cancel_check += 1\n            if tokens_since_cancel_check >= self.check_for_cancel_every:\n                tokens_since_cancel_check = 0\n                self.agree_on_cancellations()\n                if self.should_cancel(task.task_id):\n                    self._cancelled_tasks.add(task.task_id)\n\n                self.agree_on_tasks()\n\n        return self._mlx_gen.submit(\n            task_params=task.task_params,\n            prompt=prompt,\n            on_prefill_progress=on_prefill_progress,\n            distributed_prompt_progress_callback=distributed_prompt_progress_callback,\n            on_generation_token=on_generation_token,\n        )\n\n    def close(self) -> None:\n        self._mlx_gen.close()\n        del self.model, self.tokenizer, self.group\n"
  },
  {
    "path": "src/exo/worker/runner/llm_inference/model_output_parsers.py",
    "content": "from collections.abc import Generator\nfrom functools import cache\nfrom typing import Any\n\nfrom mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model\nfrom mlx_lm.models.gpt_oss import Model as GptOssModel\nfrom mlx_lm.tokenizer_utils import TokenizerWrapper\nfrom openai_harmony import (  # pyright: ignore[reportMissingTypeStubs]\n    HarmonyEncodingName,\n    HarmonyError,  # pyright: ignore[reportUnknownVariableType]\n    Role,\n    StreamableParser,\n    load_harmony_encoding,\n)\n\nfrom exo.api.types import ToolCallItem\nfrom exo.shared.types.common import ModelId\nfrom exo.shared.types.mlx import Model\nfrom exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse\nfrom exo.worker.engines.mlx.utils_mlx import (\n    detect_thinking_prompt_suffix,\n)\nfrom exo.worker.runner.bootstrap import logger\nfrom exo.worker.runner.llm_inference.tool_parsers import ToolParser\n\n\n@cache\ndef get_gpt_oss_encoding():\n    encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)\n    return encoding\n\n\ndef apply_all_parsers(\n    receiver: Generator[GenerationResponse | None],\n    prompt: str,\n    tool_parser: ToolParser | None,\n    tokenizer: TokenizerWrapper,\n    model_type: type[Model],\n    model_id: ModelId,\n    tools: list[dict[str, Any]] | None,\n) -> Generator[GenerationResponse | ToolCallResponse | None]:\n    mlx_generator = receiver\n\n    if tokenizer.has_thinking:\n        mlx_generator = parse_thinking_models(\n            mlx_generator,\n            tokenizer.think_start,\n            tokenizer.think_end,\n            starts_in_thinking=detect_thinking_prompt_suffix(prompt, tokenizer),\n        )\n\n    if issubclass(model_type, GptOssModel):\n        mlx_generator = parse_gpt_oss(mlx_generator)\n    elif (\n        issubclass(model_type, DeepseekV32Model)\n        and \"deepseek\" in model_id.normalize().lower()\n    ):\n        mlx_generator = parse_deepseek_v32(mlx_generator)\n    elif tool_parser:\n        mlx_generator = parse_tool_calls(mlx_generator, tool_parser, tools)\n\n    return mlx_generator\n\n\ndef parse_gpt_oss(\n    responses: Generator[GenerationResponse | None],\n) -> Generator[GenerationResponse | ToolCallResponse | None]:\n    encoding = get_gpt_oss_encoding()\n    stream = StreamableParser(encoding, role=Role.ASSISTANT)\n    thinking = False\n    current_tool_name: str | None = None\n    tool_arg_parts: list[str] = []\n\n    for response in responses:\n        if response is None:\n            yield None\n            continue\n        try:\n            stream.process(response.token)\n        except HarmonyError:\n            logger.error(\"Encountered critical Harmony Error, returning early\")\n            return\n\n        delta = stream.last_content_delta\n        ch = stream.current_channel\n        recipient = stream.current_recipient\n\n        # Debug: log every token with state\n        logger.debug(\n            f\"parse_gpt_oss token={response.token} text={response.text!r} \"\n            f\"recipient={recipient!r} ch={ch!r} delta={delta!r} \"\n            f\"state={stream.state} current_tool={current_tool_name!r}\"\n        )\n\n        if recipient != current_tool_name:\n            if current_tool_name is not None:\n                prefix = \"functions.\"\n                if current_tool_name.startswith(prefix):\n                    current_tool_name = current_tool_name[len(prefix) :]\n                logger.info(\n                    f\"parse_gpt_oss yielding tool call: 
name={current_tool_name!r}\"\n                )\n                yield ToolCallResponse(\n                    tool_calls=[\n                        ToolCallItem(\n                            name=current_tool_name,\n                            arguments=\"\".join(tool_arg_parts).strip(),\n                        )\n                    ],\n                    usage=response.usage,\n                )\n                tool_arg_parts = []\n            current_tool_name = recipient\n\n        # If inside a tool call, accumulate arguments\n        if current_tool_name is not None:\n            if delta:\n                tool_arg_parts.append(delta)\n            if response.finish_reason is not None:\n                yield response.model_copy(update={\"text\": \"\".join(tool_arg_parts)})\n                tool_arg_parts = []\n            continue\n\n        if ch == \"analysis\" and not thinking:\n            thinking = True\n\n        if ch != \"analysis\" and thinking:\n            thinking = False\n\n        if delta:\n            yield response.model_copy(update={\"text\": delta, \"is_thinking\": thinking})\n\n        if response.finish_reason is not None:\n            yield response\n\n\ndef parse_deepseek_v32(\n    responses: Generator[GenerationResponse | None],\n) -> Generator[GenerationResponse | ToolCallResponse | None]:\n    \"\"\"Parse DeepSeek V3.2 DSML tool calls from the generation stream.\n\n    Uses accumulated-text matching (not per-token marker checks) because\n    DSML markers like <｜DSML｜function_calls> may span multiple tokens.\n    Also handles <think>...</think> blocks for thinking mode.\n    \"\"\"\n    from exo.worker.engines.mlx.dsml_encoding import (\n        THINKING_END,\n        THINKING_START,\n        TOOL_CALLS_END,\n        TOOL_CALLS_START,\n        parse_dsml_output,\n    )\n\n    accumulated = \"\"\n    in_tool_call = False\n    thinking = False\n    # Tokens buffered while we detect the start of a DSML block\n    pending_buffer: list[GenerationResponse] = []\n    # Text accumulated during a tool call block\n    tool_call_text = \"\"\n\n    for response in responses:\n        if response is None:\n            yield None\n            continue\n\n        # ── Handle thinking tags ──\n        if not thinking and THINKING_START in response.text:\n            thinking = True\n            # Yield any text before the <think> tag\n            before = response.text[: response.text.index(THINKING_START)]\n            if before:\n                yield response.model_copy(update={\"text\": before})\n            continue\n\n        if thinking and THINKING_END in response.text:\n            thinking = False\n            # Yield any text after the </think> tag\n            after = response.text[\n                response.text.index(THINKING_END) + len(THINKING_END) :\n            ]\n            if after:\n                yield response.model_copy(update={\"text\": after, \"is_thinking\": False})\n            continue\n\n        if thinking:\n            yield response.model_copy(update={\"is_thinking\": True})\n            continue\n\n        # ── Handle tool call accumulation ──\n        if in_tool_call:\n            tool_call_text += response.text\n            if TOOL_CALLS_END in tool_call_text:\n                # Parse the accumulated DSML block\n                parsed = parse_dsml_output(tool_call_text)\n                if parsed is not None:\n                    logger.info(f\"parsed DSML tool calls: {parsed}\")\n                    yield ToolCallResponse(\n          
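              # usage/stats are carried over from the token that closed the DSML block\n          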
              tool_calls=parsed,\n                        usage=response.usage,\n                        stats=response.stats,\n                    )\n                else:\n                    logger.warning(\n                        f\"DSML tool call parsing failed for: {tool_call_text}\"\n                    )\n                    yield response.model_copy(update={\"text\": tool_call_text})\n                in_tool_call = False\n                tool_call_text = \"\"\n                continue\n\n            # EOS reached before end marker — yield buffered text as-is\n            if response.finish_reason is not None:\n                logger.info(\"DSML tool call parsing interrupted by EOS\")\n                yield response.model_copy(update={\"text\": tool_call_text})\n                in_tool_call = False\n                tool_call_text = \"\"\n            continue\n\n        # ── Detect start of tool call block ──\n        accumulated += response.text\n\n        if TOOL_CALLS_START in accumulated:\n            # The start marker might be split across pending_buffer + current token\n            start_idx = accumulated.index(TOOL_CALLS_START)\n            # Yield any pending tokens that are purely before the marker\n            pre_text = accumulated[:start_idx]\n            if pre_text:\n                # Flush pending buffer tokens that contributed text before the marker\n                for buf_resp in pending_buffer:\n                    if pre_text:\n                        chunk = buf_resp.text\n                        if len(chunk) <= len(pre_text):\n                            yield buf_resp\n                            pre_text = pre_text[len(chunk) :]\n                        else:\n                            yield buf_resp.model_copy(update={\"text\": pre_text})\n                            pre_text = \"\"\n            pending_buffer = []\n            tool_call_text = accumulated[start_idx:]\n            accumulated = \"\"\n\n            # Check if the end marker is already present (entire tool call in one token)\n            if TOOL_CALLS_END in tool_call_text:\n                parsed = parse_dsml_output(tool_call_text)\n                if parsed is not None:\n                    logger.info(f\"parsed DSML tool calls: {parsed}\")\n                    yield ToolCallResponse(\n                        tool_calls=parsed,\n                        usage=response.usage,\n                        stats=response.stats,\n                    )\n                else:\n                    logger.warning(\n                        f\"DSML tool call parsing failed for: {tool_call_text}\"\n                    )\n                    yield response.model_copy(update={\"text\": tool_call_text})\n                tool_call_text = \"\"\n            else:\n                in_tool_call = True\n            continue\n\n        # Check if accumulated text might be the start of a DSML marker\n        # Buffer tokens if we see a partial match at the end\n        if _could_be_dsml_prefix(accumulated):\n            pending_buffer.append(response)\n            continue\n\n        # No partial match — flush all pending tokens and the current one\n        for buf_resp in pending_buffer:\n            yield buf_resp\n        pending_buffer = []\n        accumulated = \"\"\n        yield response\n\n    # Flush any remaining pending buffer at generator end\n    for buf_resp in pending_buffer:\n        yield buf_resp\n\n\ndef _could_be_dsml_prefix(text: str) -> bool:\n    \"\"\"Check if the end of text could be the 
start of a DSML function_calls marker.\n\n    We look for suffixes of text that are prefixes of the TOOL_CALLS_START pattern.\n    This allows us to buffer tokens until we can determine if a tool call is starting.\n    \"\"\"\n    from exo.worker.engines.mlx.dsml_encoding import TOOL_CALLS_START\n\n    # Only check the last portion of text that could overlap with the marker\n    max_check = len(TOOL_CALLS_START)\n    tail = text[-max_check:] if len(text) > max_check else text\n\n    # Check if any suffix of tail is a prefix of TOOL_CALLS_START\n    for i in range(len(tail)):\n        suffix = tail[i:]\n        if TOOL_CALLS_START.startswith(suffix):\n            return True\n    return False\n\n\ndef parse_thinking_models(\n    responses: Generator[GenerationResponse | None],\n    think_start: str | None,\n    think_end: str | None,\n    starts_in_thinking: bool = True,\n) -> Generator[GenerationResponse | None]:\n    \"\"\"Route thinking tokens via is_thinking flag.\n\n    Swallows think tag tokens, sets is_thinking on all others.\n    Always yields tokens with finish_reason to avoid hanging the chunk stream.\n    \"\"\"\n    is_thinking = starts_in_thinking\n    for response in responses:\n        if response is None:\n            yield None\n            continue\n        if response.finish_reason is not None:\n            yield response.model_copy(update={\"is_thinking\": False})\n            continue\n\n        if response.text == think_start:\n            is_thinking = True\n            continue\n        if response.text == think_end:\n            is_thinking = False\n            continue\n\n        yield response.model_copy(update={\"is_thinking\": is_thinking})\n\n\ndef parse_tool_calls(\n    responses: Generator[GenerationResponse | None],\n    tool_parser: ToolParser,\n    tools: list[dict[str, Any]] | None,\n) -> Generator[GenerationResponse | ToolCallResponse | None]:\n    in_tool_call = False\n    tool_call_text_parts: list[str] = []\n    for response in responses:\n        if response is None:\n            yield None\n            continue\n\n        if not in_tool_call and response.text.startswith(tool_parser.start_parsing):\n            in_tool_call = True\n\n        if not in_tool_call:\n            yield response\n            continue\n\n        tool_call_text_parts.append(response.text)\n        if response.text.endswith(tool_parser.end_parsing):\n            # parse the actual tool calls from the tool call text\n            combined = \"\".join(tool_call_text_parts)\n            parsed = tool_parser.parse(combined.strip(), tools=tools)\n            logger.info(f\"parsed {tool_call_text_parts=} into {parsed=}\")\n            in_tool_call = False\n            tool_call_text_parts = []\n\n            if parsed is None:\n                logger.warning(f\"tool call parsing failed for text {combined}\")\n                yield response.model_copy(update={\"text\": combined})\n                continue\n\n            yield ToolCallResponse(\n                tool_calls=parsed, usage=response.usage, stats=response.stats\n            )\n            continue\n\n        if response.finish_reason is not None:\n            logger.info(\n                \"tool call parsing interrupted, yield partial tool call as text\"\n            )\n            response = response.model_copy(\n                update={\n                    \"text\": \"\".join(tool_call_text_parts),\n                    \"token\": 0,\n                }\n            )\n            yield response\n"
  },
  {
    "path": "src/exo/worker/runner/llm_inference/runner.py",
    "content": "import os\nimport time\nfrom dataclasses import dataclass\nfrom enum import Enum\n\nimport mlx.core as mx\nfrom anyio import WouldBlock\nfrom mlx_lm.tokenizer_utils import TokenizerWrapper\n\nfrom exo.shared.models.model_cards import ModelTask\nfrom exo.shared.types.chunks import (\n    ErrorChunk,\n    TokenChunk,\n    ToolCallChunk,\n)\nfrom exo.shared.types.common import CommandId, ModelId\nfrom exo.shared.types.events import (\n    ChunkGenerated,\n    Event,\n    RunnerStatusUpdated,\n    TaskAcknowledged,\n    TaskStatusUpdated,\n)\nfrom exo.shared.types.mlx import Model\nfrom exo.shared.types.tasks import (\n    ConnectToGroup,\n    LoadModel,\n    Shutdown,\n    StartWarmup,\n    Task,\n    TaskId,\n    TaskStatus,\n    TextGeneration,\n)\nfrom exo.shared.types.worker.instances import BoundInstance\nfrom exo.shared.types.worker.runner_response import (\n    GenerationResponse,\n    ToolCallResponse,\n)\nfrom exo.shared.types.worker.runners import (\n    RunnerConnected,\n    RunnerConnecting,\n    RunnerFailed,\n    RunnerIdle,\n    RunnerLoaded,\n    RunnerLoading,\n    RunnerReady,\n    RunnerRunning,\n    RunnerShutdown,\n    RunnerShuttingDown,\n    RunnerStatus,\n    RunnerWarmingUp,\n)\nfrom exo.utils.channels import MpReceiver, MpSender\nfrom exo.worker.engines.mlx.cache import KVPrefixCache\nfrom exo.worker.engines.mlx.utils_mlx import (\n    initialize_mlx,\n    load_mlx_items,\n)\nfrom exo.worker.runner.bootstrap import logger\nfrom exo.worker.runner.llm_inference.batch_generator import (\n    BatchGenerator,\n    InferenceGenerator,\n    SequentialGenerator,\n)\n\nfrom .batch_generator import Cancelled, Finished\nfrom .tool_parsers import make_mlx_parser\n\n\nclass ExitCode(str, Enum):\n    AllTasksComplete = \"AllTasksComplete\"\n    Shutdown = \"Shutdown\"\n\n\nclass Runner:\n    def __init__(\n        self,\n        bound_instance: BoundInstance,\n        event_sender: MpSender[Event],\n        task_receiver: MpReceiver[Task],\n        cancel_receiver: MpReceiver[TaskId],\n    ):\n        self.event_sender = event_sender\n        self.task_receiver = task_receiver\n        self.cancel_receiver = cancel_receiver\n        self.bound_instance = bound_instance\n\n        self.instance, self.runner_id, self.shard_metadata = (\n            self.bound_instance.instance,\n            self.bound_instance.bound_runner_id,\n            self.bound_instance.bound_shard,\n        )\n        self.model_id = self.shard_metadata.model_card.model_id\n        self.device_rank = self.shard_metadata.device_rank\n\n        logger.info(\"hello from the runner\")\n        if getattr(self.shard_metadata, \"immediate_exception\", False):\n            raise Exception(\"Fake exception - runner failed to spin up.\")\n        if timeout := getattr(self.shard_metadata, \"should_timeout\", 0):\n            time.sleep(timeout)\n\n        self.setup_start_time = time.time()\n\n        self.generator: Builder | InferenceGenerator = Builder(\n            self.model_id, self.event_sender, self.cancel_receiver\n        )\n\n        self.seen: set[TaskId] = set()\n        self.active_tasks: dict[\n            TaskId,\n            TextGeneration,\n        ] = {}\n\n        logger.info(\"runner created\")\n        self.update_status(RunnerIdle())\n\n    def update_status(self, status: RunnerStatus):\n        self.current_status = status\n        self.event_sender.send(\n            RunnerStatusUpdated(\n                runner_id=self.runner_id, runner_status=self.current_status\n            
)\n        )\n\n    def send_task_status(self, task_id: TaskId, task_status: TaskStatus):\n        self.event_sender.send(\n            TaskStatusUpdated(task_id=task_id, task_status=task_status)\n        )\n\n    def acknowledge_task(self, task: Task):\n        self.event_sender.send(TaskAcknowledged(task_id=task.task_id))\n\n    def main(self):\n        with self.task_receiver:\n            for task in self.task_receiver:\n                if task.task_id in self.seen:\n                    logger.warning(\"repeat task - potential error\")\n                    continue\n                self.seen.add(task.task_id)\n                self.handle_first_task(task)\n                if isinstance(self.current_status, RunnerShutdown):\n                    break\n\n    def handle_first_task(self, task: Task):\n        self.send_task_status(task.task_id, TaskStatus.Running)\n\n        match task:\n            case ConnectToGroup() if isinstance(\n                self.current_status, (RunnerIdle, RunnerFailed)\n            ):\n                assert isinstance(self.generator, Builder)\n                logger.info(\"runner connecting\")\n                self.update_status(RunnerConnecting())\n                self.acknowledge_task(task)\n\n                self.generator.group = initialize_mlx(self.bound_instance)\n\n                self.send_task_status(task.task_id, TaskStatus.Complete)\n                self.update_status(RunnerConnected())\n                logger.info(\"runner connected\")\n\n            # Load the model when the runner is connected with a group, or idle without one; a runner should never be asked to connect unless it needs a group.\n            case LoadModel() if isinstance(self.generator, Builder) and (\n                (\n                    isinstance(self.current_status, RunnerConnected)\n                    and self.generator.group is not None\n                )\n                or (\n                    isinstance(self.current_status, RunnerIdle)\n                    and self.generator.group is None\n                )\n            ):\n                total_layers = (\n                    self.shard_metadata.end_layer - self.shard_metadata.start_layer\n                )\n                logger.info(\"runner loading\")\n\n                self.update_status(\n                    RunnerLoading(layers_loaded=0, total_layers=total_layers)\n                )\n                self.acknowledge_task(task)\n\n                def on_model_load_timeout() -> None:\n                    self.update_status(\n                        RunnerFailed(error_message=\"Model loading timed out\")\n                    )\n                    time.sleep(0.5)\n\n                def on_layer_loaded(layers_loaded: int, total: int) -> None:\n                    self.update_status(\n                        RunnerLoading(layers_loaded=layers_loaded, total_layers=total)\n                    )\n\n                assert (\n                    ModelTask.TextGeneration in self.shard_metadata.model_card.tasks\n                ), f\"Incorrect model task(s): {self.shard_metadata.model_card.tasks}\"\n                self.generator.inference_model, self.generator.tokenizer = (\n                    load_mlx_items(\n                        self.bound_instance,\n                        self.generator.group,\n                        on_timeout=on_model_load_timeout,\n                        on_layer_loaded=on_layer_loaded,\n                    )\n                )\n\n                self.generator = self.generator.build()\n\n          
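      # The Builder is consumed here; self.generator is now a concrete\n                # InferenceGenerator ready to accept generation tasks.\n          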
      self.send_task_status(task.task_id, TaskStatus.Complete)\n                self.update_status(RunnerLoaded())\n                logger.info(\"runner loaded\")\n\n            case StartWarmup() if isinstance(self.current_status, RunnerLoaded):\n                assert isinstance(self.generator, InferenceGenerator)\n                logger.info(\"runner warming up\")\n\n                self.update_status(RunnerWarmingUp())\n                self.acknowledge_task(task)\n\n                self.generator.warmup()\n\n                logger.info(\n                    f\"runner initialized in {time.time() - self.setup_start_time} seconds\"\n                )\n\n                self.send_task_status(task.task_id, TaskStatus.Complete)\n                self.update_status(RunnerReady())\n                logger.info(\"runner ready\")\n\n            case TextGeneration() if isinstance(self.current_status, RunnerReady):\n                return_code = self.handle_generation_tasks(starting_task=task)\n                if return_code == ExitCode.Shutdown:\n                    return\n\n            case Shutdown():\n                self.shutdown(task)\n                return\n\n            case _:\n                raise ValueError(\n                    f\"Received {task.__class__.__name__} outside of state machine in {self.current_status=}\"\n                )\n\n    def shutdown(self, task: Task):\n        logger.info(\"runner shutting down\")\n        self.update_status(RunnerShuttingDown())\n        self.acknowledge_task(task)\n        if isinstance(self.generator, InferenceGenerator):\n            self.generator.close()\n        mx.clear_cache()\n        import gc\n\n        gc.collect()\n        self.send_task_status(task.task_id, TaskStatus.Complete)\n        self.update_status(RunnerShutdown())\n\n    def submit_text_generation(self, task: TextGeneration):\n        assert isinstance(self.generator, InferenceGenerator)\n        self.active_tasks[task.task_id] = task\n        self.generator.submit(task)\n\n    def handle_generation_tasks(self, starting_task: TextGeneration):\n        assert isinstance(self.current_status, RunnerReady)\n        assert isinstance(self.generator, InferenceGenerator)\n\n        logger.info(f\"received chat request: {starting_task}\")\n        self.update_status(RunnerRunning())\n        logger.info(\"runner running\")\n        self.acknowledge_task(starting_task)\n        self.seen.add(starting_task.task_id)\n\n        self.submit_text_generation(starting_task)\n\n        while self.active_tasks:\n            results = self.generator.step()\n\n            finished: list[TaskId] = []\n            for task_id, result in results:\n                match result:\n                    case Cancelled():\n                        finished.append(task_id)\n                    case Finished():\n                        self.send_task_status(task_id, TaskStatus.Complete)\n                        finished.append(task_id)\n                    case _:\n                        self.send_response(\n                            result, self.active_tasks[task_id].command_id\n                        )\n\n            for task_id in finished:\n                self.active_tasks.pop(task_id, None)\n\n            try:\n                task = self.task_receiver.receive_nowait()\n\n                if task.task_id in self.seen:\n                    logger.warning(\"repeat task - potential error\")\n                    continue\n                self.seen.add(task.task_id)\n\n                match task:\n  
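                  # Only new generation tasks and shutdown are accepted mid-run;\n                    # anything else is rejected with a ValueError.\n  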
                  case TextGeneration():\n                        self.acknowledge_task(task)\n                        self.submit_text_generation(task)\n                    case Shutdown():\n                        self.shutdown(task)\n                        return ExitCode.Shutdown\n                    case _:\n                        raise ValueError(\n                            f\"Received {task.__class__.__name__} outside of state machine in {self.current_status=}\"\n                        )\n\n            except WouldBlock:\n                pass\n\n        self.update_status(RunnerReady())\n        logger.info(\"runner ready\")\n\n        return ExitCode.AllTasksComplete\n\n    def send_response(\n        self, response: GenerationResponse | ToolCallResponse, command_id: CommandId\n    ):\n        match response:\n            case GenerationResponse():\n                if self.device_rank == 0 and response.finish_reason == \"error\":\n                    self.event_sender.send(\n                        ChunkGenerated(\n                            command_id=command_id,\n                            chunk=ErrorChunk(\n                                error_message=response.text,\n                                model=self.model_id,\n                            ),\n                        )\n                    )\n\n                elif self.device_rank == 0:\n                    assert response.finish_reason not in (\n                        \"error\",\n                        \"tool_calls\",\n                        \"function_call\",\n                    )\n                    self.event_sender.send(\n                        ChunkGenerated(\n                            command_id=command_id,\n                            chunk=TokenChunk(\n                                model=self.model_id,\n                                text=response.text,\n                                token_id=response.token,\n                                usage=response.usage,\n                                finish_reason=response.finish_reason,\n                                stats=response.stats,\n                                logprob=response.logprob,\n                                top_logprobs=response.top_logprobs,\n                                is_thinking=response.is_thinking,\n                            ),\n                        )\n                    )\n            case ToolCallResponse():\n                if self.device_rank == 0:\n                    self.event_sender.send(\n                        ChunkGenerated(\n                            command_id=command_id,\n                            chunk=ToolCallChunk(\n                                tool_calls=response.tool_calls,\n                                model=self.model_id,\n                                usage=response.usage,\n                                stats=response.stats,\n                            ),\n                        )\n                    )\n\n\n@dataclass\nclass Builder:\n    model_id: ModelId\n    event_sender: MpSender[Event]\n    cancel_receiver: MpReceiver[TaskId]\n    inference_model: Model | None = None\n    tokenizer: TokenizerWrapper | None = None\n    group: mx.distributed.Group | None = None\n\n    def build(\n        self,\n    ) -> InferenceGenerator:\n        assert self.model_id\n        assert self.inference_model\n        assert self.tokenizer\n\n        tool_parser = None\n        logger.info(\n            f\"model has_tool_calling={self.tokenizer.has_tool_calling} using tokens 
{self.tokenizer.tool_call_start}, {self.tokenizer.tool_call_end}\"\n        )\n        if (\n            self.tokenizer.tool_call_start\n            and self.tokenizer.tool_call_end\n            and self.tokenizer.tool_parser  # type: ignore\n        ):\n            tool_parser = make_mlx_parser(\n                self.tokenizer.tool_call_start,\n                self.tokenizer.tool_call_end,\n                self.tokenizer.tool_parser,  # type: ignore\n            )\n\n        kv_prefix_cache = KVPrefixCache(self.group)\n\n        device_rank = 0 if self.group is None else self.group.rank()\n        if os.environ.get(\"EXO_NO_BATCH\"):\n            logger.info(\"using SequentialGenerator (batching disabled)\")\n            return SequentialGenerator(\n                model=self.inference_model,\n                tokenizer=self.tokenizer,\n                group=self.group,\n                tool_parser=tool_parser,\n                kv_prefix_cache=kv_prefix_cache,\n                model_id=self.model_id,\n                device_rank=device_rank,\n                cancel_receiver=self.cancel_receiver,\n                event_sender=self.event_sender,\n            )\n        logger.info(\"using BatchGenerator\")\n        return BatchGenerator(\n            model=self.inference_model,\n            tokenizer=self.tokenizer,\n            group=self.group,\n            tool_parser=tool_parser,\n            kv_prefix_cache=kv_prefix_cache,\n            model_id=self.model_id,\n            device_rank=device_rank,\n            cancel_receiver=self.cancel_receiver,\n            event_sender=self.event_sender,\n        )\n"
  },
  {
    "path": "src/exo/worker/runner/llm_inference/tool_parsers.py",
    "content": "import json\nimport math\nfrom dataclasses import dataclass\nfrom typing import Any, Callable\n\nfrom exo.api.types import ToolCallItem\n\n\n@dataclass\nclass ToolParser:\n    start_parsing: str\n    end_parsing: str\n    _inner_parser: Callable[[str], list[ToolCallItem] | None]\n\n    def parse(\n        self, text: str, tools: list[dict[str, Any]] | None\n    ) -> list[ToolCallItem] | None:\n        parsed = self._inner_parser(text)\n        if parsed is None:\n            return None\n        if tools is not None:\n            parsed = _coerce_tool_calls_to_schema(parsed, tools)\n        return parsed\n\n\ndef _json_type_matches(value: Any, expected_type: str) -> bool:  # pyright: ignore[reportAny]\n    if expected_type == \"object\":\n        return isinstance(value, dict)\n    if expected_type == \"array\":\n        return isinstance(value, list)\n    if expected_type == \"string\":\n        return isinstance(value, str)\n    if expected_type == \"integer\":\n        return isinstance(value, int) and not isinstance(value, bool)\n    if expected_type == \"number\":\n        return (isinstance(value, int) and not isinstance(value, bool)) or isinstance(\n            value, float\n        )\n    if expected_type == \"boolean\":\n        return isinstance(value, bool)\n    if expected_type == \"null\":\n        return value is None\n    return False\n\n\ndef _coerce_tool_arg_with_schema(value: Any, schema: dict[str, Any]) -> Any:  # pyright: ignore[reportAny]\n    schema_type = schema.get(\"type\")\n\n    if isinstance(schema_type, list):\n        for candidate in schema_type:  # pyright: ignore[reportUnknownVariableType]\n            if not isinstance(candidate, str):\n                continue\n            if candidate == \"null\" and value is None:\n                return None\n            candidate_schema = {**schema, \"type\": candidate}\n            coerced = _coerce_tool_arg_with_schema(value, candidate_schema)  # pyright: ignore[reportAny]\n            if _json_type_matches(coerced, candidate):\n                return coerced  # pyright: ignore[reportAny]\n        return value  # pyright: ignore[reportAny]\n\n    if not isinstance(schema_type, str):\n        return value  # pyright: ignore[reportAny]\n\n    if schema_type == \"object\":\n        parsed = value  # pyright: ignore[reportAny]\n        if isinstance(parsed, str):\n            try:\n                parsed = json.loads(parsed)  # pyright: ignore[reportAny]\n            except Exception:\n                return value  # pyright: ignore[reportAny]\n        if not isinstance(parsed, dict):\n            return value  # pyright: ignore[reportAny]\n        properties = schema.get(\"properties\")\n        if not isinstance(properties, dict):\n            return parsed  # pyright: ignore[reportUnknownVariableType]\n        return {\n            key: (\n                _coerce_tool_arg_with_schema(prop_value, prop_schema)  # pyright: ignore[reportUnknownArgumentType]\n                if isinstance(prop_schema, dict)\n                else prop_value\n            )\n            for key, prop_value in parsed.items()  # pyright: ignore[reportUnknownVariableType]\n            for prop_schema in [properties.get(key)]  # type: ignore\n        }\n\n    if schema_type == \"array\":\n        parsed = value  # pyright: ignore[reportAny]\n        if isinstance(parsed, str):\n            try:\n                parsed = json.loads(parsed)  # pyright: ignore[reportAny]\n            except Exception:\n                return value  
# pyright: ignore[reportAny]\n        if not isinstance(parsed, list):\n            return value  # pyright: ignore[reportAny]\n        item_schema = schema.get(\"items\")\n        if not isinstance(item_schema, dict):\n            return parsed  # pyright: ignore[reportUnknownVariableType]\n        return [_coerce_tool_arg_with_schema(item, item_schema) for item in parsed]  # type: ignore\n\n    if schema_type == \"integer\":\n        if isinstance(value, bool):\n            return value\n        if isinstance(value, int):\n            return value\n        if isinstance(value, float) and value.is_integer():\n            return int(value)\n        if isinstance(value, str):\n            try:\n                return int(value.strip())\n            except ValueError:\n                return value\n        return value\n\n    if schema_type == \"number\":\n        if isinstance(value, bool):\n            return value\n        if isinstance(value, (int, float)):\n            return value\n        if isinstance(value, str):\n            try:\n                num = float(value.strip())\n                if math.isfinite(num):\n                    return num\n            except ValueError:\n                return value\n        return value\n\n    if schema_type == \"boolean\":\n        if isinstance(value, bool):\n            return value\n        if isinstance(value, str):\n            lowered = value.strip().lower()\n            if lowered == \"true\":\n                return True\n            if lowered == \"false\":\n                return False\n        return value\n\n    return value  # pyright: ignore[reportAny]\n\n\ndef _coerce_tool_calls_to_schema(\n    tool_calls: list[ToolCallItem], tools: list[dict[str, Any]]\n) -> list[ToolCallItem]:\n    schema_by_name: dict[str, dict[str, Any]] = {}\n    for tool in tools:\n        function = tool.get(\"function\")\n        if not isinstance(function, dict):\n            continue\n        name = function.get(\"name\")  # type: ignore\n        parameters = function.get(\"parameters\")  # type: ignore\n        if isinstance(name, str) and isinstance(parameters, dict):\n            schema_by_name[name] = parameters\n\n    if not schema_by_name:\n        return tool_calls\n\n    coerced_calls: list[ToolCallItem] = []\n    for tool_call in tool_calls:\n        schema = schema_by_name.get(tool_call.name)\n        if schema is None:\n            coerced_calls.append(tool_call)\n            continue\n\n        try:\n            parsed_args = json.loads(tool_call.arguments)  # pyright: ignore[reportAny]\n        except Exception:\n            coerced_calls.append(tool_call)\n            continue\n\n        if not isinstance(parsed_args, dict):\n            coerced_calls.append(tool_call)\n            continue\n\n        coerced_args = _coerce_tool_arg_with_schema(parsed_args, schema)  # pyright: ignore[reportAny]\n        if not isinstance(coerced_args, dict):\n            coerced_calls.append(tool_call)\n            continue\n\n        coerced_calls.append(\n            tool_call.model_copy(update={\"arguments\": json.dumps(coerced_args)})\n        )\n    return coerced_calls\n\n\ndef make_mlx_parser(\n    tool_call_start: str,\n    tool_call_end: str,\n    tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]],\n) -> ToolParser:\n    def parse_tool_calls(text: str) -> list[ToolCallItem] | None:\n        try:\n            text = text.removeprefix(tool_call_start)\n            text = text.removesuffix(tool_call_end)\n            parsed = 
tool_parser(text)\n            if isinstance(parsed, list):\n                return [ToolCallItem.model_validate(_flatten(p)) for p in parsed]\n            else:\n                return [ToolCallItem.model_validate(_flatten(parsed))]\n\n        except Exception:\n            return None\n\n    return ToolParser(\n        start_parsing=tool_call_start,\n        end_parsing=tool_call_end,\n        _inner_parser=parse_tool_calls,\n    )\n\n\n# Simple <tool_call> JSON parser; used by make_json_parser() below and kept as example code:\ndef _parse_json_calls(text: str) -> list[ToolCallItem] | None:\n    try:\n        text = text.removeprefix(\"<tool_call>\")\n        text = text.removesuffix(\"</tool_call>\")\n        top_level = {\n            k: json.dumps(v) if isinstance(v, (dict, list)) else v\n            for k, v in json.loads(text).items()  # pyright: ignore[reportAny]\n        }\n        return [ToolCallItem.model_validate(top_level)]\n    except Exception:\n        return None\n\n\ndef _flatten(p: dict[str, Any]) -> dict[str, str]:\n    return {\n        k: json.dumps(v) if isinstance(v, (dict, list)) else str(v)  # pyright: ignore[reportAny]\n        for k, v in p.items()  # pyright: ignore[reportAny]\n    }\n\n\ndef make_json_parser() -> ToolParser:\n    return ToolParser(\n        start_parsing=\"<tool_call>\",\n        end_parsing=\"</tool_call>\",\n        _inner_parser=_parse_json_calls,\n    )\n\n\ndef infer_tool_parser(chat_template: str) -> ToolParser | None:\n    \"\"\"Attempt to auto-infer a tool parser from the chat template.\"\"\"\n    if \"<tool_call>\" in chat_template and \"tool_call.name\" in chat_template:\n        return make_json_parser()\n    return None\n"
  },
  {
    "path": "src/exo/worker/runner/runner_supervisor.py",
    "content": "import contextlib\nimport multiprocessing as mp\nimport signal\nfrom dataclasses import dataclass, field\nfrom typing import Self\n\nimport anyio\nfrom anyio import (\n    BrokenResourceError,\n    ClosedResourceError,\n    to_thread,\n)\nfrom loguru import logger\n\nfrom exo.shared.types.chunks import ErrorChunk\nfrom exo.shared.types.events import (\n    ChunkGenerated,\n    Event,\n    RunnerStatusUpdated,\n    TaskAcknowledged,\n    TaskStatusUpdated,\n)\nfrom exo.shared.types.tasks import (\n    CANCEL_ALL_TASKS,\n    ImageEdits,\n    ImageGeneration,\n    Task,\n    TaskId,\n    TaskStatus,\n    TextGeneration,\n)\nfrom exo.shared.types.worker.instances import BoundInstance\nfrom exo.shared.types.worker.runners import (\n    RunnerConnecting,\n    RunnerFailed,\n    RunnerIdle,\n    RunnerLoading,\n    RunnerRunning,\n    RunnerShuttingDown,\n    RunnerStatus,\n    RunnerWarmingUp,\n)\nfrom exo.shared.types.worker.shards import ShardMetadata\nfrom exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel\nfrom exo.utils.task_group import TaskGroup\nfrom exo.worker.runner.bootstrap import entrypoint\n\nPREFILL_TIMEOUT_SECONDS = 60\nDECODE_TIMEOUT_SECONDS = 5\n\n\n@dataclass(eq=False)\nclass RunnerSupervisor:\n    shard_metadata: ShardMetadata\n    bound_instance: BoundInstance\n    runner_process: mp.Process\n    initialize_timeout: float\n    _ev_recv: MpReceiver[Event]\n    _task_sender: MpSender[Task]\n    _event_sender: Sender[Event]\n    _cancel_sender: MpSender[TaskId]\n    _tg: TaskGroup = field(default_factory=TaskGroup, init=False)\n    status: RunnerStatus = field(default_factory=RunnerIdle, init=False)\n    pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)\n    in_progress: dict[TaskId, Task] = field(default_factory=dict, init=False)\n    completed: set[TaskId] = field(default_factory=set, init=False)\n    cancelled: set[TaskId] = field(default_factory=set, init=False)\n    _cancel_watch_runner: anyio.CancelScope = field(\n        default_factory=anyio.CancelScope, init=False\n    )\n\n    @classmethod\n    def create(\n        cls,\n        *,\n        bound_instance: BoundInstance,\n        event_sender: Sender[Event],\n        initialize_timeout: float = 400,\n    ) -> Self:\n        ev_send, ev_recv = mp_channel[Event]()\n        task_sender, task_recv = mp_channel[Task]()\n        cancel_sender, cancel_recv = mp_channel[TaskId]()\n\n        runner_process = mp.Process(\n            target=entrypoint,\n            args=(\n                bound_instance,\n                ev_send,\n                task_recv,\n                cancel_recv,\n                logger,\n            ),\n            daemon=True,\n        )\n\n        shard_metadata = bound_instance.bound_shard\n\n        self = cls(\n            bound_instance=bound_instance,\n            shard_metadata=shard_metadata,\n            runner_process=runner_process,\n            initialize_timeout=initialize_timeout,\n            _ev_recv=ev_recv,\n            _task_sender=task_sender,\n            _cancel_sender=cancel_sender,\n            _event_sender=event_sender,\n        )\n\n        return self\n\n    async def run(self):\n        self.runner_process.start()\n        async with self._tg as tg:\n            tg.start_soon(self._watch_runner)\n            tg.start_soon(self._forward_events)\n\n    def shutdown(self):\n        logger.info(\"Runner supervisor shutting down\")\n        self._tg.cancel_tasks()\n        if not 
self._cancel_watch_runner.cancel_called:\n            self._cancel_watch_runner.cancel()\n        with contextlib.suppress(ClosedResourceError):\n            self._ev_recv.close()\n        with contextlib.suppress(ClosedResourceError):\n            self._task_sender.close()\n        with contextlib.suppress(ClosedResourceError):\n            self._event_sender.close()\n        with contextlib.suppress(ClosedResourceError):\n            self._cancel_sender.send(CANCEL_ALL_TASKS)\n        with contextlib.suppress(ClosedResourceError):\n            self._cancel_sender.close()\n        self.runner_process.join(5)\n        if not self.runner_process.is_alive():\n            logger.info(\"Runner process successfully terminated\")\n            return\n\n        # This is overkill but it's not technically bad, just unnecessary.\n        logger.warning(\"Runner process didn't shut down successfully, terminating\")\n        self.runner_process.terminate()\n        self.runner_process.join(1)\n        if not self.runner_process.is_alive():\n            return\n\n        logger.critical(\"Runner process didn't respond to SIGTERM, killing\")\n        self.runner_process.kill()\n\n    async def start_task(self, task: Task):\n        if task.task_id in self.pending:\n            logger.warning(\n                f\"Skipping invalid task {task} as it has already been submitted\"\n            )\n            return\n        if task.task_id in self.completed:\n            logger.warning(\n                f\"Skipping invalid task {task} as it has already been completed\"\n            )\n            return\n        logger.info(f\"Starting task {task}\")\n        event = anyio.Event()\n        self.pending[task.task_id] = event\n        self.in_progress[task.task_id] = task\n        try:\n            await self._task_sender.send_async(task)\n        except ClosedResourceError:\n            self.in_progress.pop(task.task_id, None)\n            logger.warning(f\"Task {task} dropped, runner closed communication.\")\n            return\n        await event.wait()\n\n    async def cancel_task(self, task_id: TaskId):\n        if task_id in self.completed:\n            logger.info(f\"Unable to cancel {task_id} as it has been completed\")\n            self.cancelled.add(task_id)\n            return\n        self.cancelled.add(task_id)\n        with anyio.move_on_after(0.5) as scope:\n            try:\n                await self._cancel_sender.send_async(task_id)\n            except ClosedResourceError:\n                # typically occurs when trying to shut down a failed instance\n                logger.warning(\n                    f\"Cancelling task {task_id} failed, runner closed communication\"\n                )\n        if scope.cancel_called:\n            logger.error(\"RunnerSupervisor cancel pipe blocked\")\n            await self._check_runner(TimeoutError(\"cancel pipe blocked\"))\n\n    async def _forward_events(self):\n        try:\n            with self._ev_recv as events:\n                async for event in events:\n                    if isinstance(event, RunnerStatusUpdated):\n                        self.status = event.runner_status\n                    if isinstance(event, TaskAcknowledged):\n                        self.pending.pop(event.task_id).set()\n                        continue\n                    if (\n                        isinstance(event, TaskStatusUpdated)\n                        and event.task_status == TaskStatus.Complete\n                    ):\n                        # If a task 
has just been completed, we should be working on it.\n                        assert isinstance(\n                            self.status,\n                            (\n                                RunnerRunning,\n                                RunnerWarmingUp,\n                                RunnerLoading,\n                                RunnerConnecting,\n                                RunnerShuttingDown,\n                            ),\n                        )\n                        self.in_progress.pop(event.task_id, None)\n                        self.completed.add(event.task_id)\n                    await self._event_sender.send(event)\n        except (ClosedResourceError, BrokenResourceError) as e:\n            await self._check_runner(e)\n        finally:\n            for tid in self.pending:\n                self.pending[tid].set()\n\n    def __del__(self) -> None:\n        if self.runner_process.is_alive():\n            logger.critical(\"RunnerSupervisor was not stopped cleanly.\")\n            with contextlib.suppress(ValueError):\n                self.runner_process.kill()\n\n    async def _watch_runner(self) -> None:\n        with self._cancel_watch_runner:\n            while True:\n                await anyio.sleep(5)\n                if not self.runner_process.is_alive():\n                    await self._check_runner(RuntimeError(\"Runner found to be dead\"))\n\n    async def _check_runner(self, e: Exception) -> None:\n        if not self._cancel_watch_runner.cancel_called:\n            self._cancel_watch_runner.cancel()\n        logger.info(\"Checking runner's status\")\n        if self.runner_process.is_alive():\n            logger.info(\"Runner was found to be alive, attempting to join process\")\n            await to_thread.run_sync(self.runner_process.join, 5)\n        rc = self.runner_process.exitcode\n        logger.info(f\"Runner exited with exit code {rc}\")\n        if rc == 0:\n            return\n\n        if isinstance(rc, int) and rc < 0:\n            sig = -rc\n            try:\n                cause = f\"signal={sig} ({signal.strsignal(sig)})\"\n            except Exception:\n                cause = f\"signal={sig}\"\n        else:\n            cause = f\"exitcode={rc}\"\n\n        logger.opt(exception=e).error(f\"Runner terminated with {cause}\")\n\n        for task in self.in_progress.values():\n            if isinstance(task, (TextGeneration, ImageGeneration, ImageEdits)):\n                with anyio.CancelScope(shield=True):\n                    await self._event_sender.send(\n                        ChunkGenerated(\n                            command_id=task.command_id,\n                            chunk=ErrorChunk(\n                                model=self.shard_metadata.model_card.model_id,\n                                error_message=(\n                                    \"Runner shutdown before completing command \"\n                                    f\"({cause})\"\n                                ),\n                            ),\n                        )\n                    )\n\n        try:\n            self.status = RunnerFailed(error_message=f\"Terminated ({cause})\")\n            with anyio.CancelScope(shield=True):\n                await self._event_sender.send(\n                    RunnerStatusUpdated(\n                        runner_id=self.bound_instance.bound_runner_id,\n                        runner_status=RunnerFailed(\n                            error_message=f\"Terminated ({cause})\"\n                        
),\n                    )\n                )\n        except (ClosedResourceError, BrokenResourceError):\n            logger.warning(\n                \"Event sender already closed, unable to report runner failure\"\n            )\n        self.shutdown()\n"
  },
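The `shutdown` method above escalates in three stages: a bounded `join` so the runner can exit cooperatively, then `terminate` (SIGTERM), then `kill` (SIGKILL). A minimal standalone sketch of that escalation pattern, using only the standard library; the function name and timeouts here are illustrative, not part of the project:

```python
# Sketch of the join -> terminate -> kill escalation used by
# RunnerSupervisor.shutdown. Names and timeouts are illustrative.
import multiprocessing as mp


def stop_child(process: mp.Process, grace_seconds: float = 5.0) -> None:
    # Stage 1: give the child a chance to exit on its own.
    process.join(grace_seconds)
    if not process.is_alive():
        return

    # Stage 2: ask with SIGTERM, which the child may catch and handle.
    process.terminate()
    process.join(1.0)
    if not process.is_alive():
        return

    # Stage 3: SIGKILL cannot be caught or ignored.
    process.kill()
    process.join()
```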
  {
    "path": "src/exo/worker/tests/TODO.tests",
    "content": "Unit Tests\n1. Test worker plans as expected\n - State transitions are correct\n - Unexpected states throw\n\n2. Test runner\n - Stays loaded\n - Unloads under end condition\n - Accepts tasks\n - Returns ChunkGenerated events\n\n3. Test mlx engine\n - Autoparallel on n of the same nodes returns tensors with 1/n size\n - mx.barrier forces computation\n - Distributed init returns expected configuration\n - initialize_mlx sets wired limit\n - shard_and_load returns expected model\n - Quantization returns quantized layers\n\n 4. Download\n  - hits the correct endpoint\n  - normalizes tags correctly\n  - updates download progress\n\n 5. Serialization/Deserialization of tagged models\n\n\n\n\n\nIntegration tests:\n1. Test model inference is \"sensible\" (per-configuration)\n - Non-empty response\n - Sensible inference speed\n - Answers are non-gibberish for many seeds (What is the capital of France? -> \"Paris\" in answer.)\n - Answer is the same for particular seed\n\n2. Test that node count does not affect inference result (per-configuration)\n - Llama on 1 node, and on 2 nodes returns the same result, given temperature 0 and set seed.\n - Do for all configurations (Ring/Jaccl, Pipeline/Tensor)\n\n3. Test supervisor catches exceptions gracefully\n - Timeouts\n - OOM\n - MLX error\n\n4. distributed init memory requirements are as expected\n\n5. MLX\n - KVCache size is same length as prompt tokens\n - Prefix cache (once implemented)\n\n6. Spin up creates a runner or goes to failed status\n\n\nRegression tests:\n1. Per-configuration baseline performance - no 20% drop in performance (device, node count, model, strategy, backend)\n"
  },
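The seed-determinism items in the integration list above translate almost directly into an assertion helper. A hedged sketch, where `TextGenerator` is a hypothetical stand-in protocol for whatever inference entry point the tests end up calling:

```python
# Sketch of the "same answer for a particular seed" check from the TODO
# list. TextGenerator is a hypothetical protocol, not the project's API.
from typing import Protocol


class TextGenerator(Protocol):
    def __call__(self, prompt: str, temperature: float, seed: int) -> str: ...


def check_seed_determinism(generate: TextGenerator) -> None:
    """Assert the two TODO properties: a fixed seed at temperature 0
    reproduces the answer, and the answer is non-gibberish."""
    prompt = "What is the capital of France?"
    first = generate(prompt=prompt, temperature=0.0, seed=7)
    second = generate(prompt=prompt, temperature=0.0, seed=7)
    assert first == second, "same seed must reproduce the same answer"
    assert "Paris" in first, "answer should mention Paris"
```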
  {
    "path": "src/exo/worker/tests/__init__.py",
    "content": ""
  },
  {
    "path": "src/exo/worker/tests/constants.py",
    "content": "from typing import Final\n\nfrom exo.shared.models.model_cards import ModelId\nfrom exo.shared.types.common import CommandId, NodeId\nfrom exo.shared.types.tasks import TaskId\nfrom exo.shared.types.worker.instances import InstanceId, RunnerId\n\nMASTER_NODE_ID = NodeId(\"ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa\")\n\nNODE_A: Final[NodeId] = NodeId(\"aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa\")\nNODE_B: Final[NodeId] = NodeId(\"bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb\")\nNODE_C: Final[NodeId] = NodeId(\"cccccccc-cccc-4ccc-8ccc-cccccccccccc\")\n\nRUNNER_1_ID: Final[RunnerId] = RunnerId(\"11111111-1111-4111-8111-111111111111\")\nRUNNER_2_ID: Final[RunnerId] = RunnerId(\"33333333-3333-4333-8333-333333333333\")\nRUNNER_3_ID: Final[RunnerId] = RunnerId(\"Runner3\")\n\nINSTANCE_1_ID: Final[InstanceId] = InstanceId(\"22222222-2222-4222-8222-222222222222\")\nINSTANCE_2_ID: Final[InstanceId] = InstanceId(\"44444444-4444-4444-8444-444444444444\")\n\nMODEL_A_ID: Final[ModelId] = ModelId(\"mlx-community/Llama-3.2-1B-Instruct-4bit\")\nMODEL_B_ID: Final[ModelId] = ModelId(\"mlx-community/TinyLlama-1.1B-Chat-v1.0\")\n\nTASK_1_ID: Final[TaskId] = TaskId(\"55555555-5555-4555-8555-555555555555\")\nTASK_2_ID: Final[TaskId] = TaskId(\"66666666-6666-4666-8666-666666666666\")\n\nCOMMAND_1_ID: Final[CommandId] = CommandId(\"77777777-7777-4777-8777-777777777777\")\nCOMMAND_2_ID: Final[CommandId] = CommandId(\"88888888-8888-4888-8888-888888888888\")\n\nSHUTDOWN_TASK_ID = TaskId(\"shutdown\")\nCHAT_COMPLETION_TASK_ID = TaskId(\"chat-completion\")\nINITIALIZATION_TASK_ID = TaskId(\"initialisation\")\nLOAD_TASK_ID = TaskId(\"load\")\nWARMUP_TASK_ID = TaskId(\"warmup\")\n"
  },
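The hand-written IDs above are not arbitrary hex: each keeps a valid UUIDv4 shape, with a `4` version nibble and an `8` variant nibble (`RUNNER_3_ID` is the one exception, a plain string). A quick stdlib check of that property:

```python
# The fixed test IDs keep a valid UUIDv4 shape: the third group starts
# with "4" (version) and the fourth with "8" (RFC 4122 variant).
import uuid

FIXED_TEST_ID = "aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa"

parsed = uuid.UUID(FIXED_TEST_ID)
assert parsed.version == 4
assert parsed.variant == uuid.RFC_4122
```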
  {
    "path": "src/exo/worker/tests/unittests/__init__.py",
    "content": ""
  },
  {
    "path": "src/exo/worker/tests/unittests/conftest.py",
    "content": "from dataclasses import dataclass, field\n\nfrom exo.shared.models.model_cards import ModelCard, ModelId, ModelTask\nfrom exo.shared.types.common import NodeId\nfrom exo.shared.types.memory import Memory\nfrom exo.shared.types.tasks import BaseTask, TaskId\nfrom exo.shared.types.worker.instances import (\n    BoundInstance,\n    Instance,\n    InstanceId,\n    MlxRingInstance,\n)\nfrom exo.shared.types.worker.runners import RunnerId, RunnerStatus, ShardAssignments\nfrom exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata\n\n\n# Runner supervisor without multiprocessing logic.\n@dataclass(frozen=True)\nclass FakeRunnerSupervisor:\n    bound_instance: BoundInstance\n    status: RunnerStatus\n    completed: set[TaskId] = field(default_factory=set)\n    in_progress: set[TaskId] = field(default_factory=set)\n    pending: dict[TaskId, object] = field(default_factory=dict)\n\n\nclass OtherTask(BaseTask):\n    pass\n\n\n# TODO: Is this actually better than using Mock/Fake dataclasses?\n#  e.g. commit d01cd292344df15759070966826a6c027945792b\ndef get_pipeline_shard_metadata(\n    model_id: ModelId, device_rank: int, world_size: int = 1\n) -> ShardMetadata:\n    return PipelineShardMetadata(\n        model_card=ModelCard(\n            model_id=model_id,\n            storage_size=Memory.from_mb(100000),\n            n_layers=32,\n            hidden_size=2048,\n            supports_tensor=False,\n            tasks=[ModelTask.TextGeneration],\n        ),\n        device_rank=device_rank,\n        world_size=world_size,\n        start_layer=0,\n        end_layer=32,\n        n_layers=32,\n    )\n\n\ndef get_shard_assignments(\n    model_id: ModelId,\n    node_to_runner: dict[NodeId, RunnerId],\n    runner_to_shard: dict[RunnerId, ShardMetadata],\n) -> ShardAssignments:\n    return ShardAssignments(\n        model_id=model_id,\n        node_to_runner=node_to_runner,\n        runner_to_shard=runner_to_shard,\n    )\n\n\ndef get_mlx_ring_instance(\n    instance_id: InstanceId,\n    model_id: ModelId,\n    node_to_runner: dict[NodeId, RunnerId],\n    runner_to_shard: dict[RunnerId, ShardMetadata],\n) -> Instance:\n    return MlxRingInstance(\n        instance_id=instance_id,\n        shard_assignments=get_shard_assignments(\n            model_id, node_to_runner, runner_to_shard\n        ),\n        hosts_by_node={},\n        ephemeral_port=50000,\n    )\n\n\ndef get_bound_mlx_ring_instance(\n    instance_id: InstanceId, model_id: ModelId, runner_id: RunnerId, node_id: NodeId\n) -> BoundInstance:\n    shard = get_pipeline_shard_metadata(model_id=model_id, device_rank=0, world_size=2)\n    other_shard = get_pipeline_shard_metadata(\n        model_id=model_id, device_rank=1, world_size=2\n    )\n    instance = get_mlx_ring_instance(\n        instance_id=instance_id,\n        model_id=model_id,\n        node_to_runner={\n            node_id: runner_id,\n            NodeId(\"other_node\"): RunnerId(\"other_runner\"),\n        },\n        runner_to_shard={runner_id: shard, RunnerId(\"other_runner\"): other_shard},\n    )\n    return BoundInstance(\n        instance=instance, bound_runner_id=runner_id, bound_node_id=node_id\n    )\n"
  },
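In a test, the builder helpers above compose like this; a hedged usage sketch, reusing identifiers from the constants module shown earlier:

```python
# Hedged usage sketch for the builder helpers above: wire a two-node
# ring instance and bind it to one runner on one node.
from exo.worker.tests.constants import INSTANCE_1_ID, MODEL_A_ID, NODE_A, RUNNER_1_ID
from exo.worker.tests.unittests.conftest import get_bound_mlx_ring_instance

bound = get_bound_mlx_ring_instance(
    instance_id=INSTANCE_1_ID,
    model_id=MODEL_A_ID,
    runner_id=RUNNER_1_ID,
    node_id=NODE_A,
)
assert bound.bound_runner_id == RUNNER_1_ID
assert bound.bound_node_id == NODE_A
```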
  {
    "path": "src/exo/worker/tests/unittests/test_download/__init__.py",
    "content": ""
  },
  {
    "path": "src/exo/worker/tests/unittests/test_mlx/__init__.py",
    "content": ""
  },
  {
    "path": "src/exo/worker/tests/unittests/test_mlx/conftest.py",
    "content": "# type: ignore\nimport json\nimport os\nimport tempfile\nimport traceback\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import Any, cast\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nfrom exo.shared.constants import EXO_MODELS_DIR\nfrom exo.shared.models.model_cards import ModelCard, ModelTask\nfrom exo.shared.types.common import ModelId\nfrom exo.shared.types.memory import Memory\nfrom exo.shared.types.mlx import Model\nfrom exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams\nfrom exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata\nfrom exo.worker.engines.mlx.generator.generate import mlx_generate\nfrom exo.worker.engines.mlx.utils_mlx import apply_chat_template, shard_and_load\n\n\nclass MockLayer(nn.Module):\n    def __init__(self) -> None:\n        super().__init__()\n        self.custom_attr = \"test_value\"\n        self.use_sliding = True\n\n    def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:\n        return x * 2\n\n\n@dataclass(frozen=True)\nclass PipelineTestConfig:\n    model_path: Path\n    total_layers: int\n    base_port: int\n    max_tokens: int\n\n\ndef create_hostfile(world_size: int, base_port: int) -> tuple[str, list[str]]:\n    hosts = [f\"127.0.0.1:{base_port + i}\" for i in range(world_size)]\n\n    with tempfile.NamedTemporaryFile(mode=\"w\", suffix=\".json\", delete=False) as f:\n        json.dump(hosts, f)\n        hostfile_path = f.name\n\n    return hostfile_path, hosts\n\n\n# Use GPT OSS 20b to test as it is a model with a lot of strange behaviour\n\nDEFAULT_GPT_OSS_CONFIG = PipelineTestConfig(\n    model_path=EXO_MODELS_DIR / \"mlx-community--gpt-oss-20b-MXFP4-Q8\",\n    total_layers=24,\n    base_port=29600,\n    max_tokens=200,\n)\n\n\nDEFAULT_GPT_OSS_MODEL_ID = \"mlx-community/gpt-oss-20b-MXFP4-Q8\"\n\n\ndef run_gpt_oss_pipeline_device(\n    rank: int,\n    world_size: int,\n    hostfile_path: str,\n    layer_splits: list[tuple[int, int]],\n    prompt_tokens: int,\n    prefill_step_size: int,\n    result_queue: Any,  # pyright: ignore[reportAny]\n    max_tokens: int = 200,\n) -> None:\n    os.environ[\"MLX_HOSTFILE\"] = hostfile_path\n    os.environ[\"MLX_RANK\"] = str(rank)\n\n    try:\n        group = mx.distributed.init(backend=\"ring\", strict=True)\n\n        start_layer, end_layer = layer_splits[rank]\n\n        shard_meta = PipelineShardMetadata(\n            model_card=ModelCard(\n                model_id=ModelId(DEFAULT_GPT_OSS_MODEL_ID),\n                storage_size=Memory.from_gb(12),\n                n_layers=24,\n                hidden_size=2880,\n                supports_tensor=False,\n                tasks=[ModelTask.TextGeneration],\n            ),\n            device_rank=rank,\n            world_size=world_size,\n            start_layer=start_layer,\n            end_layer=end_layer,\n            n_layers=24,\n        )\n\n        model, tokenizer = shard_and_load(\n            shard_meta, group, on_timeout=None, on_layer_loaded=None\n        )\n        model = cast(Model, model)\n\n        # Generate a prompt of exact token length\n        base_text = \"The quick brown fox jumps over the lazy dog. 
\"\n        base_tokens = tokenizer.encode(base_text)\n        base_len = len(base_tokens)\n\n        # Build prompt with approximate target length\n        repeats = (prompt_tokens // base_len) + 2\n        long_text = base_text * repeats\n        tokens = tokenizer.encode(long_text)\n        # Truncate to exact target length\n        tokens = tokens[:prompt_tokens]\n        prompt_text = tokenizer.decode(tokens)\n\n        task = TextGenerationTaskParams(\n            model=DEFAULT_GPT_OSS_MODEL_ID,\n            input=[InputMessage(role=\"user\", content=prompt_text)],\n            max_output_tokens=max_tokens,\n        )\n\n        prompt = apply_chat_template(tokenizer, task)\n\n        generated_text = \"\"\n\n        for response in mlx_generate(\n            model=model,\n            tokenizer=tokenizer,\n            task=task,\n            prompt=prompt,\n            kv_prefix_cache=None,\n            group=group,\n        ):\n            generated_text += response.text\n            if response.finish_reason is not None:\n                break\n\n        result_queue.put((rank, True, generated_text))  # pyright: ignore[reportAny]\n\n    except Exception as e:\n        result_queue.put((rank, False, f\"{e}\\n{traceback.format_exc()}\"))  # pyright: ignore[reportAny]\n\n\ndef run_gpt_oss_tensor_parallel_device(\n    rank: int,\n    world_size: int,\n    hostfile_path: str,\n    prompt_tokens: int,\n    prefill_step_size: int,\n    result_queue: Any,  # pyright: ignore[reportAny]\n    max_tokens: int = 10,\n) -> None:\n    os.environ[\"MLX_HOSTFILE\"] = hostfile_path\n    os.environ[\"MLX_RANK\"] = str(rank)\n\n    try:\n        group = mx.distributed.init(backend=\"ring\", strict=True)\n\n        # For tensor parallelism, all devices run all layers\n        shard_meta = TensorShardMetadata(\n            model_card=ModelCard(\n                model_id=ModelId(DEFAULT_GPT_OSS_MODEL_ID),\n                storage_size=Memory.from_gb(12),\n                n_layers=24,\n                hidden_size=2880,\n                supports_tensor=True,\n                tasks=[ModelTask.TextGeneration],\n            ),\n            device_rank=rank,\n            world_size=world_size,\n            start_layer=0,\n            end_layer=24,\n            n_layers=24,\n        )\n\n        model, tokenizer = shard_and_load(\n            shard_meta, group, on_timeout=None, on_layer_loaded=None\n        )\n        model = cast(Model, model)\n\n        base_text = \"The quick brown fox jumps over the lazy dog. 
\"\n        base_tokens = tokenizer.encode(base_text)\n        base_len = len(base_tokens)\n\n        repeats = (prompt_tokens // base_len) + 2\n        long_text = base_text * repeats\n        tokens = tokenizer.encode(long_text)\n        tokens = tokens[:prompt_tokens]\n        prompt_text = tokenizer.decode(tokens)\n\n        task = TextGenerationTaskParams(\n            model=DEFAULT_GPT_OSS_MODEL_ID,\n            input=[InputMessage(role=\"user\", content=prompt_text)],\n            max_output_tokens=max_tokens,\n        )\n\n        prompt = apply_chat_template(tokenizer, task)\n\n        generated_text = \"\"\n        for response in mlx_generate(\n            model=model,\n            tokenizer=tokenizer,\n            task=task,\n            prompt=prompt,\n            kv_prefix_cache=None,\n            group=group,\n        ):\n            generated_text += response.text\n            if response.finish_reason is not None:\n                break\n\n        result_queue.put((rank, True, generated_text))  # pyright: ignore[reportAny]\n\n    except Exception as e:\n        result_queue.put((rank, False, f\"{e}\\n{traceback.format_exc()}\"))  # pyright: ignore[reportAny]\n"
  },
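Both GPT-OSS device entry points size their prompts the same way: overshoot by repeating a base sentence, then truncate the token list to the exact target. Factored out as a sketch; the `Tokenizer` protocol is a stand-in for the wrapper returned by `shard_and_load`:

```python
# Sketch of the prompt-sizing pattern used in both device entry points:
# repeat a base sentence past the target token count, then truncate.
from typing import Protocol


class Tokenizer(Protocol):
    def encode(self, text: str) -> list[int]: ...
    def decode(self, tokens: list[int]) -> str: ...


def build_prompt_with_token_count(tokenizer: Tokenizer, target_tokens: int) -> str:
    base_text = "The quick brown fox jumps over the lazy dog. "
    base_length = len(tokenizer.encode(base_text))
    repeats = (target_tokens // base_length) + 2  # overshoot, then trim
    tokens = tokenizer.encode(base_text * repeats)[:target_tokens]
    return tokenizer.decode(tokens)
```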
  {
    "path": "src/exo/worker/tests/unittests/test_mlx/test_auto_parallel.py",
    "content": "import json\nimport multiprocessing as mp\nimport os\nimport tempfile\nfrom typing import Any\n\nimport mlx.core as mx\nimport mlx.nn as mlx_nn\nimport pytest\n\nfrom exo.worker.engines.mlx.auto_parallel import (\n    CustomMlxLayer,\n    PipelineFirstLayer,\n    PipelineLastLayer,\n    patch_pipeline_model,\n)\nfrom exo.worker.tests.unittests.test_mlx.conftest import MockLayer\n\n\ndef run_pipeline_device(\n    rank: int,\n    world_size: int,\n    hostfile_path: str,\n    result_queue: Any,  # pyright: ignore[reportAny]\n) -> None:\n    import os\n\n    os.environ[\"MLX_HOSTFILE\"] = hostfile_path\n    os.environ[\"MLX_RANK\"] = str(rank)\n\n    class MockLayerInner(mlx_nn.Module):\n        def __init__(self) -> None:\n            super().__init__()\n            self.custom_attr = \"test_value\"\n\n        def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:\n            return x * 2\n\n    class MockModel(mlx_nn.Module):\n        def __init__(self, layers: list[mlx_nn.Module]) -> None:\n            super().__init__()\n            self.layers = layers\n\n        def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:\n            for layer in self.layers:\n                x = layer(x, *args, **kwargs)  # pyright: ignore[reportUnknownVariableType]\n            return x  # pyright: ignore[reportUnknownVariableType]\n\n    try:\n        group = mx.distributed.init(backend=\"ring\", strict=True)\n\n        mock = MockLayerInner()\n        first = PipelineFirstLayer(mock, r=rank, group=group)\n        composed = PipelineLastLayer(first, r=rank, s=world_size, group=group)\n\n        # Wrap in a mock model, then wrap in PipelineParallelModel for all_gather\n        inner_model = MockModel([composed])\n        model = patch_pipeline_model(inner_model, group)\n\n        x = mx.ones((1, 4))\n        result = model(x)\n        mx.eval(result)\n        success = result.shape == x.shape\n        result_queue.put((rank, success, result))  # pyright: ignore[reportAny]\n    except Exception as e:\n        result_queue.put((rank, False, str(e)))  # pyright: ignore[reportAny]\n\n\ndef test_single_wrapper_delegates_attributes() -> None:\n    mock = MockLayer()\n    wrapped = CustomMlxLayer(mock)\n\n    assert wrapped.custom_attr == \"test_value\"  # type: ignore[attr-defined]\n    assert wrapped.use_sliding is True  # type: ignore[attr-defined]\n\n\ndef test_composed_wrappers_delegate_attributes() -> None:\n    mock = MockLayer()\n    group = mx.distributed.init()\n\n    first = PipelineFirstLayer(mock, r=0, group=group)\n    composed = PipelineLastLayer(first, r=0, s=1, group=group)\n\n    assert composed.custom_attr == \"test_value\"  # type: ignore[attr-defined]\n    assert composed.use_sliding is True  # type: ignore[attr-defined]\n\n\ndef test_missing_attribute_raises() -> None:\n    mock = MockLayer()\n    wrapped = CustomMlxLayer(mock)\n\n    with pytest.raises(AttributeError):\n        _ = wrapped.nonexistent_attr  # type: ignore[attr-defined]\n\n\ndef test_composed_call_works() -> None:\n    ctx = mp.get_context(\"spawn\")\n\n    world_size = 2\n    base_port = 29500\n\n    hosts = [f\"127.0.0.1:{base_port + i}\" for i in range(world_size)]\n\n    with tempfile.NamedTemporaryFile(mode=\"w\", suffix=\".json\", delete=False) as f:\n        json.dump(hosts, f)\n        hostfile_path = f.name\n\n    try:\n        result_queue: Any = ctx.Queue()\n\n        processes: list[Any] = []\n        for rank in range(world_size):\n            p = 
ctx.Process(\n                target=run_pipeline_device,\n                args=(rank, world_size, hostfile_path, result_queue),\n            )\n            p.start()\n            processes.append(p)\n\n        for p in processes:  # pyright: ignore[reportAny]\n            p.join(timeout=10)  # pyright: ignore[reportAny]\n\n        results: dict[int, Any] = {}\n        errors: dict[int, str] = {}\n        while not result_queue.empty():  # pyright: ignore[reportAny]\n            rank, success, value = result_queue.get()  # pyright: ignore[reportAny]\n            if success:\n                results[rank] = value\n            else:\n                errors[rank] = value\n\n        assert len(results) == world_size, (\n            f\"Expected {world_size} results, got {len(results)}. Errors: {errors}\"\n        )\n\n        for rank in range(world_size):\n            assert rank in results, (\n                f\"Device {rank} failed: {errors.get(rank, 'unknown')}\"\n            )\n            result_array = results[rank]\n            # Both devices see the final result (4.0) after all_gather\n            assert (result_array == 4.0).all(), (\n                f\"Device {rank}: expected 4.0, got {result_array}\"\n            )\n    finally:\n        os.unlink(hostfile_path)\n"
  },
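What `test_single_wrapper_delegates_attributes` and its siblings pin down is plain `__getattr__` forwarding: lookups that miss on the wrapper fall through to the wrapped layer, and genuinely missing names still raise `AttributeError`. A minimal MLX-free sketch of that behaviour (the real `CustomMlxLayer` is an `nn.Module`, so its implementation has more moving parts):

```python
# Minimal sketch of the attribute-delegation behaviour the tests above
# verify: unknown lookups fall through to the wrapped object, and names
# missing on both sides still raise AttributeError.
class DelegatingWrapper:
    def __init__(self, inner: object) -> None:
        self._inner = inner

    def __getattr__(self, name: str) -> object:
        # Only invoked when normal lookup fails on the wrapper itself.
        return getattr(self._inner, name)


class Layer:
    custom_attr = "test_value"


wrapped = DelegatingWrapper(Layer())
assert wrapped.custom_attr == "test_value"
try:
    wrapped.nonexistent_attr
except AttributeError:
    pass
```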
  {
    "path": "src/exo/worker/tests/unittests/test_mlx/test_batch_vs_generate.py",
    "content": "import copy\nimport gc\nimport json\nimport shutil\nimport tempfile\nfrom pathlib import Path\nfrom typing import Any, cast\n\nimport mlx.core as mx\nimport pytest\nfrom mlx_lm.tokenizer_utils import TokenizerWrapper\n\nfrom exo.shared.types.common import ModelId\nfrom exo.shared.types.mlx import KVCacheType, Model\nfrom exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams\nfrom exo.worker.engines.mlx.cache import CacheSnapshot, KVPrefixCache, cache_length\nfrom exo.worker.engines.mlx.generator.batch_generate import ExoBatchGenerator\nfrom exo.worker.engines.mlx.generator.generate import mlx_generate\nfrom exo.worker.engines.mlx.utils_mlx import (\n    apply_chat_template,\n    load_tokenizer_for_model_id,\n)\n\nfrom .test_prefix_cache_architectures import (\n    ARCHITECTURES,\n    ArchSpec,\n    _arch_available,  # pyright: ignore[reportPrivateUsage]\n    _build_model,  # pyright: ignore[reportPrivateUsage]\n    _copy_tokenizer,  # pyright: ignore[reportPrivateUsage]\n    _find_snapshot,  # pyright: ignore[reportPrivateUsage]\n    _reduce_config,  # pyright: ignore[reportPrivateUsage]\n)\n\n\ndef _make_task(\n    content: str = \"Hello, what is 2+2?\",\n    max_tokens: int = 10,\n    seed: int = 42,\n) -> TextGenerationTaskParams:\n    return TextGenerationTaskParams(\n        model=ModelId(\"test\"),\n        input=[InputMessage(role=\"user\", content=content)],\n        max_output_tokens=max_tokens,\n        temperature=0.7,\n        seed=seed,\n    )\n\n\n# ── Helpers ──────────────────────────────────────────────────────────────── #\n\n\ndef _collect_mlx_generate(\n    model: Model,\n    tokenizer: TokenizerWrapper,\n    task: TextGenerationTaskParams,\n    kv_prefix_cache: KVPrefixCache | None,\n) -> list[int]:\n    \"\"\"Run mlx_generate and collect output token IDs.\"\"\"\n    prompt = apply_chat_template(tokenizer=tokenizer, task_params=task)\n    tokens: list[int] = []\n    for resp in mlx_generate(\n        model=model,\n        tokenizer=tokenizer,\n        task=task,\n        prompt=prompt,\n        kv_prefix_cache=kv_prefix_cache,\n        group=None,\n    ):\n        tokens.append(resp.token)\n        if resp.finish_reason is not None:\n            break\n    return tokens\n\n\ndef _collect_batch_generate(\n    model: Model,\n    tokenizer: TokenizerWrapper,\n    task_params: TextGenerationTaskParams,\n    kv_prefix_cache: KVPrefixCache | None,\n) -> list[int]:\n    \"\"\"Run ExoBatchGenerator and collect raw output token IDs\"\"\"\n    exo_gen = ExoBatchGenerator(\n        model=model,\n        tokenizer=tokenizer,\n        group=None,\n        kv_prefix_cache=kv_prefix_cache,\n    )\n\n    prompt = apply_chat_template(tokenizer=tokenizer, task_params=task_params)\n    exo_gen.submit(task_params=task_params, prompt=prompt)\n\n    tokens: list[int] = []\n    while exo_gen.has_work:\n        results = exo_gen.step()\n        for _uid, response in results:\n            tokens.append(response.token)\n\n    exo_gen.close()\n    return tokens\n\n\ndef _assert_state_equal(sa: object, sb: object, label: str) -> None:\n    \"\"\"Compare two state items, handling both plain arrays and tuples of arrays (CacheList).\"\"\"\n    if isinstance(sa, tuple):\n        assert isinstance(sb, tuple), f\"{label}: type mismatch\"\n        for k, (arr_a, arr_b) in enumerate(\n            zip(\n                cast(tuple[mx.array, ...], sa),\n                cast(tuple[mx.array, ...], sb),\n                strict=True,\n            )\n        ):\n           
 a_f = mx.array(arr_a).astype(mx.float32)\n            b_f = mx.array(arr_b).astype(mx.float32)\n            if a_f.size == 0:\n                assert b_f.size == 0, f\"{label}[{k}]: size mismatch\"\n                continue\n            diff = float(mx.max(mx.abs(a_f - b_f)).item())\n            assert diff == 0.0, f\"{label}[{k}]: max diff {diff}\"\n    else:\n        sa_f = mx.array(cast(mx.array, sa)).astype(mx.float32)\n        sb_f = mx.array(cast(mx.array, sb)).astype(mx.float32)\n        if sa_f.size == 0:\n            assert sb_f.size == 0, f\"{label}: size mismatch\"\n            return\n        diff = float(mx.max(mx.abs(sa_f - sb_f)).item())\n        assert diff == 0.0, f\"{label}: max diff {diff}\"\n\n\ndef _compare_cache_arrays(\n    cache_a: KVCacheType,\n    cache_b: KVCacheType,\n    label: str = \"\",\n) -> None:\n    \"\"\"Assert two KV caches have identical array values.\"\"\"\n    assert len(cache_a) == len(cache_b), (\n        f\"{label}Cache layer count: {len(cache_a)} vs {len(cache_b)}\"\n    )\n    for i, (a, b) in enumerate(zip(cache_a, cache_b, strict=True)):\n        assert type(a) is type(b), (\n            f\"{label}Layer {i}: type {type(a).__name__} vs {type(b).__name__}\"\n        )\n        states_a = a.state\n        states_b = b.state\n        assert len(states_a) == len(states_b), (\n            f\"{label}Layer {i}: state count {len(states_a)} vs {len(states_b)}\"\n        )\n        for j, (sa, sb) in enumerate(zip(states_a, states_b, strict=True)):\n            if sa is None and sb is None:\n                continue\n            assert sa is not None and sb is not None, (\n                f\"{label}Layer {i}, state {j}: one is None\"\n            )\n            _assert_state_equal(sa, sb, f\"{label}Layer {i}, state {j}\")\n\n\ndef _safe_state(cache: object) -> list[object]:\n    \"\"\"Safely access .state on a cache object. 
Returns [] if uninitialized.\"\"\"\n    # RotatingKVCache.state crashes when keys is None (uninitialized)\n    if getattr(cache, \"keys\", _SENTINEL) is None:\n        return []\n    try:\n        return list(cache.state)  # type: ignore[union-attr]\n    except (AttributeError, TypeError):\n        return []\n\n\n_SENTINEL = object()\n\n\ndef _compare_snapshots(\n    snaps_a: list[CacheSnapshot] | None,\n    snaps_b: list[CacheSnapshot] | None,\n    label: str = \"\",\n) -> None:\n    \"\"\"Assert two snapshot lists are identical.\"\"\"\n    if snaps_a is None:\n        assert snaps_b is None, f\"{label}One side has snapshots, other doesn't\"\n        return\n    assert snaps_b is not None, f\"{label}One side has snapshots, other doesn't\"\n    assert len(snaps_a) == len(snaps_b), (\n        f\"{label}Snapshot count: {len(snaps_a)} vs {len(snaps_b)}\"\n    )\n    for k, (sa, sb) in enumerate(zip(snaps_a, snaps_b, strict=True)):\n        assert sa.token_count == sb.token_count, (\n            f\"{label}Snapshot {k} token_count: {sa.token_count} vs {sb.token_count}\"\n        )\n        for layer_i, (s1, s2) in enumerate(zip(sa.states, sb.states, strict=True)):\n            if s1 is None and s2 is None:\n                continue\n            assert s1 is not None and s2 is not None, (\n                f\"{label}Snapshot {k}, layer {layer_i}: one state is None\"\n            )\n            state_a = _safe_state(s1)\n            state_b = _safe_state(s2)\n            if not state_a and not state_b:\n                continue\n            assert len(state_a) == len(state_b), (\n                f\"{label}Snapshot {k}, layer {layer_i}: state length mismatch\"\n            )\n            for st_j, (arr_a, arr_b) in enumerate(zip(state_a, state_b, strict=True)):\n                if arr_a is None and arr_b is None:\n                    continue\n                assert arr_a is not None and arr_b is not None\n                _assert_state_equal(\n                    arr_a,\n                    arr_b,\n                    f\"{label}Snapshot {k}, layer {layer_i}, state {st_j}\",\n                )\n\n\n# ── Test class ────────────────────────────────────────────────────────────── #\n\n\n@pytest.mark.slow\nclass TestBatchVsGenerate:\n    \"\"\"Verify BatchGenerator matches mlx_generate for output tokens and prefix cache.\"\"\"\n\n    @pytest.fixture(autouse=True)\n    def _cleanup(self):\n        yield\n        mx.clear_cache()\n        gc.collect()\n\n    @pytest.mark.parametrize(\n        \"spec\",\n        ARCHITECTURES,\n        ids=[a.name for a in ARCHITECTURES],\n    )\n    def test_same_output_and_cache(self, spec: ArchSpec) -> None:\n        if not _arch_available(spec):\n            pytest.skip(f\"Model {spec.hub_name} not cached locally\")\n\n        snapshot = _find_snapshot(spec.hub_name)\n        assert snapshot is not None\n\n        tmpdir = Path(tempfile.mkdtemp(prefix=f\"exo_batchtest_{spec.name}_\"))\n        try:\n            # Build reduced config\n            with open(snapshot / \"config.json\") as f:\n                cfg = cast(dict[str, Any], json.load(f))\n            reduced = _reduce_config(copy.deepcopy(cfg))\n            (tmpdir / \"config.json\").write_text(json.dumps(reduced))\n\n            # Copy tokenizer\n            tok_src = snapshot\n            if spec.tokenizer_hub is not None:\n                alt = _find_snapshot(spec.tokenizer_hub)\n                if alt is not None:\n                    tok_src = alt\n            _copy_tokenizer(tok_src, tmpdir)\n\n           
 # Load tokenizer, build model with random weights\n            model_id = ModelId(f\"mlx-community/{spec.hub_name}\")\n            tokenizer = load_tokenizer_for_model_id(model_id, tmpdir)\n            mx.random.seed(0)\n            model = _build_model(spec.module, reduced)\n\n            task = _make_task()\n\n            # ── Run mlx_generate path ──\n            # Seed is set inside mlx_generate/ExoBatchGenerator.submit from task.seed\n            kv_mlx = KVPrefixCache(None)\n            mlx_tokens = _collect_mlx_generate(model, tokenizer, task, kv_mlx)\n\n            # ── Run batch generator path ──\n            kv_batch = KVPrefixCache(None)\n            batch_tokens = _collect_batch_generate(model, tokenizer, task, kv_batch)\n\n            # ── Compare output tokens ──\n            assert len(mlx_tokens) > 0, \"mlx_generate produced no tokens\"\n            assert len(batch_tokens) > 0, \"BatchGenerator produced no tokens\"\n            assert mlx_tokens == batch_tokens, (\n                f\"[{spec.name}] Token mismatch:\\n\"\n                f\"  mlx_generate:    {mlx_tokens}\\n\"\n                f\"  BatchGenerator:  {batch_tokens}\"\n            )\n\n            # ── Compare prefix cache KV arrays ──\n            assert len(kv_mlx.caches) == 1, \"mlx_generate didn't save to prefix cache\"\n            assert len(kv_batch.caches) == 1, (\n                \"BatchGenerator didn't save to prefix cache\"\n            )\n\n            _compare_cache_arrays(\n                kv_mlx.caches[0],\n                kv_batch.caches[0],\n                label=f\"[{spec.name}] \",\n            )\n\n            # ── Compare cache lengths ──\n            mlx_len = cache_length(kv_mlx.caches[0])\n            batch_len = cache_length(kv_batch.caches[0])\n            assert mlx_len == batch_len, (\n                f\"[{spec.name}] Cache length: mlx={mlx_len} vs batch={batch_len}\"\n            )\n\n            # ── Compare snapshots ──\n            _compare_snapshots(\n                kv_mlx._snapshots[0],  # pyright: ignore[reportPrivateUsage]\n                kv_batch._snapshots[0],  # pyright: ignore[reportPrivateUsage]\n                label=f\"[{spec.name}] \",\n            )\n\n        finally:\n            shutil.rmtree(tmpdir, ignore_errors=True)\n\n    @pytest.mark.parametrize(\n        \"spec\",\n        ARCHITECTURES,\n        ids=[a.name for a in ARCHITECTURES],\n    )\n    def test_concurrent_batch_completes(self, spec: ArchSpec) -> None:\n        \"\"\"Two requests processed concurrently must both complete without\n        crashing and produce non-empty output.\n\n        Note: batch decode logits are NOT bit-exact with sequential because\n        Metal's matmul kernel picks different reduction tiling for B=1 vs B=2\n        when L=1 (decode step). 
This introduces sub-ULP float16 diffs in\n        gate_proj/down_proj/lm_head which swiglu amplifies by |up_values|.\n        With random weights these accumulate into argmax flips; with trained\n        weights the diffs are absorbed and output matches exactly (verified\n        with real Llama-3.2-1B-Instruct-4bit weights).\n        \"\"\"\n        if not _arch_available(spec):\n            pytest.skip(f\"Model {spec.hub_name} not cached locally\")\n\n        snapshot = _find_snapshot(spec.hub_name)\n        assert snapshot is not None\n\n        tmpdir = Path(tempfile.mkdtemp(prefix=f\"exo_concurrent_{spec.name}_\"))\n        try:\n            with open(snapshot / \"config.json\") as f:\n                cfg = cast(dict[str, Any], json.load(f))\n            reduced = _reduce_config(copy.deepcopy(cfg))\n            (tmpdir / \"config.json\").write_text(json.dumps(reduced))\n\n            tok_src = snapshot\n            if spec.tokenizer_hub is not None:\n                alt = _find_snapshot(spec.tokenizer_hub)\n                if alt is not None:\n                    tok_src = alt\n            _copy_tokenizer(tok_src, tmpdir)\n\n            model_id = ModelId(f\"mlx-community/{spec.hub_name}\")\n            tokenizer = load_tokenizer_for_model_id(model_id, tmpdir)\n            mx.random.seed(0)\n            model = _build_model(spec.module, reduced)\n\n            # Two different prompts → different prompt lengths.\n            task_a = _make_task(content=\"Hello, what is 2+2?\", seed=42)\n            task_a = task_a.model_copy(update={\"temperature\": 0.0})\n            task_b = _make_task(\n                content=\"Write a short poem about the ocean and the sky.\",\n                seed=99,\n            )\n            task_b = task_b.model_copy(update={\"temperature\": 0.0})\n\n            # ── Concurrent: submit both to one ExoBatchGenerator ──\n            exo_gen = ExoBatchGenerator(\n                model=model,\n                tokenizer=tokenizer,\n                group=None,\n                kv_prefix_cache=None,\n            )\n\n            prompt_a = apply_chat_template(tokenizer=tokenizer, task_params=task_a)\n            prompt_b = apply_chat_template(tokenizer=tokenizer, task_params=task_b)\n            uid_a = exo_gen.submit(task_params=task_a, prompt=prompt_a)\n            uid_b = exo_gen.submit(task_params=task_b, prompt=prompt_b)\n\n            batch_tokens: dict[int, list[int]] = {uid_a: [], uid_b: []}\n            finished: set[int] = set()\n            while exo_gen.has_work:\n                results = exo_gen.step()\n                for uid, response in results:\n                    batch_tokens[uid].append(response.token)\n                    if response.finish_reason is not None:\n                        finished.add(uid)\n\n            exo_gen.close()\n\n            # ── Verify both completed ──\n            assert len(batch_tokens[uid_a]) > 0, \"No tokens for task A\"\n            assert len(batch_tokens[uid_b]) > 0, \"No tokens for task B\"\n            assert uid_a in finished, \"Task A never finished\"\n            assert uid_b in finished, \"Task B never finished\"\n        finally:\n            shutil.rmtree(tmpdir, ignore_errors=True)\n"
  },
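One judgment call in `test_same_output_and_cache` is asserting on the full token lists, which can produce long failure messages. A small hypothetical helper that reports only the first point of divergence would keep such failures readable; it is not part of the test suite:

```python
# Hypothetical divergence reporter for failing token comparisons: point
# at the first index where two token streams differ instead of dumping
# both full lists.
def first_divergence(tokens_a: list[int], tokens_b: list[int]) -> str | None:
    for index, (a, b) in enumerate(zip(tokens_a, tokens_b)):
        if a != b:
            return f"first mismatch at index {index}: {a} vs {b}"
    if len(tokens_a) != len(tokens_b):
        return f"common prefix matches; lengths differ ({len(tokens_a)} vs {len(tokens_b)})"
    return None


assert first_divergence([1, 2, 3], [1, 2, 3]) is None
assert first_divergence([1, 2, 3], [1, 9, 3]) == "first mismatch at index 1: 2 vs 9"
```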
  {
    "path": "src/exo/worker/tests/unittests/test_mlx/test_distributed_fix.py",
    "content": "import multiprocessing as mp\nimport os\nfrom dataclasses import dataclass\nfrom typing import Any, Callable\n\nimport pytest\n\nfrom exo.worker.tests.unittests.test_mlx.conftest import (\n    DEFAULT_GPT_OSS_CONFIG,\n    create_hostfile,\n    run_gpt_oss_pipeline_device,\n    run_gpt_oss_tensor_parallel_device,\n)\n\n\ndef _check_model_exists() -> bool:\n    return DEFAULT_GPT_OSS_CONFIG.model_path.exists()\n\n\npytestmark = [\n    pytest.mark.slow,\n    pytest.mark.skipif(\n        not _check_model_exists(),\n        reason=f\"GPT-OSS model not found at {DEFAULT_GPT_OSS_CONFIG.model_path}\",\n    ),\n]\n\n\n@dataclass\nclass DistributedTestResult:\n    timed_out: bool\n    world_size: int\n    results: dict[int, tuple[bool, str]]\n\n    @property\n    def all_success(self) -> bool:\n        if len(self.results) != self.world_size:\n            return False\n        return all(r[0] for r in self.results.values())\n\n\ndef run_distributed_test(\n    world_size: int,\n    port_offset: int,\n    process_timeout: int,\n    target: Callable[..., None],\n    make_args: Callable[[int], tuple[Any, ...]],\n) -> DistributedTestResult:\n    ctx = mp.get_context(\"spawn\")\n    hostfile_path, _ = create_hostfile(\n        world_size, DEFAULT_GPT_OSS_CONFIG.base_port + port_offset\n    )\n\n    try:\n        result_queue: Any = ctx.Queue()\n        processes: list[Any] = []\n\n        for rank in range(world_size):\n            args = make_args(rank)\n            p = ctx.Process(\n                target=target,\n                args=(rank, world_size, hostfile_path, *args, result_queue),\n            )\n            p.start()\n            processes.append(p)\n\n        for p in processes:  # pyright: ignore[reportAny]\n            p.join(timeout=process_timeout)  # pyright: ignore[reportAny]\n\n        timed_out = any(p.is_alive() for p in processes)  # pyright: ignore[reportAny]\n\n        for p in processes:  # pyright: ignore[reportAny]\n            if p.is_alive():  # pyright: ignore[reportAny]\n                p.terminate()  # pyright: ignore[reportAny]\n                p.join(timeout=5)  # pyright: ignore[reportAny]\n\n        results: dict[int, tuple[bool, str]] = {}\n        while not result_queue.empty():  # pyright: ignore[reportAny]\n            rank, success, value = result_queue.get()  # pyright: ignore[reportAny]\n            results[rank] = (success, value)\n\n        return DistributedTestResult(\n            timed_out=timed_out, world_size=world_size, results=results\n        )\n\n    finally:\n        os.unlink(hostfile_path)\n\n\ndef run_pipeline_test(\n    layer_splits: list[tuple[int, int]],\n    prompt_tokens: int,\n    prefill_step_size: int,\n    port_offset: int = 0,\n    process_timeout: int = 60,\n) -> DistributedTestResult:\n    def make_args(rank: int) -> tuple[Any, ...]:\n        return (\n            layer_splits,\n            prompt_tokens,\n            prefill_step_size,\n        )\n\n    return run_distributed_test(\n        world_size=len(layer_splits),\n        port_offset=port_offset,\n        process_timeout=process_timeout,\n        target=run_gpt_oss_pipeline_device,\n        make_args=make_args,\n    )\n\n\ndef run_tensor_test(\n    prompt_tokens: int,\n    prefill_step_size: int,\n    port_offset: int = 0,\n    process_timeout: int = 60,\n) -> DistributedTestResult:\n    def make_args(rank: int) -> tuple[Any, ...]:\n        return (\n            prompt_tokens,\n            prefill_step_size,\n        )\n\n    return run_distributed_test(\n       
 world_size=2,\n        port_offset=port_offset,\n        process_timeout=process_timeout,\n        target=run_gpt_oss_tensor_parallel_device,\n        make_args=make_args,\n    )\n\n\nclass TestPipelineParallelFix:\n    BUG_TRIGGER_SPLITS: list[tuple[int, int]] = [(0, 1), (1, 24)]\n\n    def test_pipeline_single_layer_first_device(self) -> None:\n        result = run_pipeline_test(\n            layer_splits=self.BUG_TRIGGER_SPLITS,\n            prompt_tokens=100,\n            prefill_step_size=64,\n            process_timeout=60,\n        )\n        assert not result.timed_out, \"Unexpected timeout - fix may not be working\"\n        assert result.all_success, f\"Failures: {result.results}\"\n\n\nclass TestPipelineSplitConfigurations:\n    @pytest.mark.parametrize(\n        \"layer_splits\",\n        [\n            [(0, 1), (1, 24)],\n            [(0, 6), (6, 24)],\n            [(0, 12), (12, 24)],\n        ],\n        ids=[\"1_23\", \"6_18\", \"12_12\"],\n    )\n    def test_pipeline_splits(\n        self,\n        layer_splits: list[tuple[int, int]],\n    ) -> None:\n        result = run_pipeline_test(\n            layer_splits=layer_splits,\n            prompt_tokens=600,\n            prefill_step_size=512,\n            port_offset=100,\n        )\n        assert not result.timed_out, f\"Timeout with {layer_splits}\"\n        assert result.all_success, f\"Failures with {layer_splits}: {result.results}\"\n\n\nclass TestPrefillStepSizeBoundaries:\n    @pytest.mark.parametrize(\n        \"prefill_step_size,prompt_tokens\",\n        [\n            (512, 511),\n            (512, 512),\n            (512, 513),\n            (512, 1024),\n        ],\n        ids=[\"under\", \"exact\", \"over\", \"double\"],\n    )\n    def test_boundary_conditions(\n        self,\n        prefill_step_size: int,\n        prompt_tokens: int,\n    ) -> None:\n        result = run_pipeline_test(\n            layer_splits=[(0, 12), (12, 24)],\n            prompt_tokens=prompt_tokens,\n            prefill_step_size=prefill_step_size,\n            port_offset=200,\n        )\n        assert not result.timed_out, f\"Timeout: {prompt_tokens=}, {prefill_step_size=}\"\n        assert result.all_success, f\"Failures: {result.results}\"\n\n\nclass TestTensorParallelFix:\n    def test_tensor_parallel(self) -> None:\n        result = run_tensor_test(\n            prompt_tokens=100,\n            prefill_step_size=64,\n            port_offset=400,\n        )\n        assert not result.timed_out, \"Unexpected timeout\"\n        assert result.all_success, f\"Failures: {result.results}\"\n\n\nclass TestTensorParallelBoundaries:\n    @pytest.mark.parametrize(\n        \"prefill_step_size,prompt_tokens\",\n        [\n            (512, 511),\n            (512, 512),\n            (512, 513),\n            (512, 1024),\n        ],\n        ids=[\"under\", \"exact\", \"over\", \"double\"],\n    )\n    def test_tensor_parallel_boundaries(\n        self,\n        prefill_step_size: int,\n        prompt_tokens: int,\n    ) -> None:\n        result = run_tensor_test(\n            prompt_tokens=prompt_tokens,\n            prefill_step_size=prefill_step_size,\n            port_offset=500,\n        )\n        assert not result.timed_out, f\"Timeout: {prompt_tokens=}, {prefill_step_size=}\"\n        assert result.all_success, f\"Failures: {result.results}\"\n"
  },
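The harness above encodes a pattern worth reusing for any distributed test: bound every `join`, treat a still-alive rank as a timeout, and terminate stragglers so one hung rank cannot stall the whole suite. A condensed, framework-free sketch of the same shape; names are illustrative, and the real harness also threads per-rank arguments and a hostfile:

```python
# Condensed sketch of the spawn-based harness pattern used above. Under
# the "spawn" context the target must be a picklable top-level function.
import multiprocessing as mp
from typing import Any, Callable


def run_ranks(world_size: int, target: Callable[..., None], timeout: float) -> dict[int, Any]:
    ctx = mp.get_context("spawn")
    queue: Any = ctx.Queue()
    processes = [ctx.Process(target=target, args=(rank, queue)) for rank in range(world_size)]
    for process in processes:
        process.start()
    for process in processes:
        process.join(timeout=timeout)
    for process in processes:
        if process.is_alive():  # straggler: kill it rather than hang the suite
            process.terminate()
            process.join(timeout=5)
    results: dict[int, Any] = {}
    while not queue.empty():
        rank, value = queue.get()
        results[rank] = value
    return results
```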
  {
    "path": "src/exo/worker/tests/unittests/test_mlx/test_kv_prefix_cache.py",
    "content": "# type: ignore\nimport time\nfrom typing import cast\nfrom unittest.mock import patch\n\nimport mlx.core as mx\nimport pytest\nfrom mlx_lm.models.cache import KVCache\nfrom mlx_lm.sample_utils import make_sampler\n\nfrom exo.shared.types.common import ModelId\nfrom exo.shared.types.mlx import Model\nfrom exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams\nfrom exo.worker.engines.mlx.cache import (\n    KVPrefixCache,\n    cache_length,\n    encode_prompt,\n    get_prefix_length,\n    make_kv_cache,\n)\nfrom exo.worker.engines.mlx.generator.generate import mlx_generate, prefill\nfrom exo.worker.engines.mlx.utils_mlx import apply_chat_template\nfrom exo.worker.tests.unittests.test_mlx.conftest import (\n    DEFAULT_GPT_OSS_CONFIG,\n    DEFAULT_GPT_OSS_MODEL_ID,\n)\n\n\ndef _check_model_exists() -> bool:\n    return DEFAULT_GPT_OSS_CONFIG.model_path.exists()\n\n\nclass TestGetPrefixLength:\n    def test_identical_arrays(self):\n        a = mx.array([1, 2, 3, 4, 5])\n        b = mx.array([1, 2, 3, 4, 5])\n        assert get_prefix_length(a, b) == 5\n\n    def test_no_common_prefix(self):\n        a = mx.array([1, 2, 3])\n        b = mx.array([4, 5, 6])\n        assert get_prefix_length(a, b) == 0\n\n    def test_partial_prefix(self):\n        a = mx.array([1, 2, 3, 4, 5])\n        b = mx.array([1, 2, 3, 7, 8])\n        assert get_prefix_length(a, b) == 3\n\n    def test_prompt_longer_than_cached(self):\n        a = mx.array([1, 2, 3, 4, 5])\n        b = mx.array([1, 2, 3])\n        assert get_prefix_length(a, b) == 3\n\n    def test_cached_longer_than_prompt(self):\n        a = mx.array([1, 2, 3])\n        b = mx.array([1, 2, 3, 4, 5])\n        assert get_prefix_length(a, b) == 3\n\n    def test_single_token_match(self):\n        a = mx.array([1, 2, 3])\n        b = mx.array([1, 5, 6])\n        assert get_prefix_length(a, b) == 1\n\n    def test_empty_prompt(self):\n        a = mx.array([]).astype(mx.int32)\n        b = mx.array([1, 2, 3])\n        assert get_prefix_length(a, b) == 0\n\n    def test_empty_cached(self):\n        a = mx.array([1, 2, 3])\n        b = mx.array([]).astype(mx.int32)\n        assert get_prefix_length(a, b) == 0\n\n    def test_both_empty(self):\n        a = mx.array([]).astype(mx.int32)\n        b = mx.array([]).astype(mx.int32)\n        assert get_prefix_length(a, b) == 0\n\n\nclass TestKVPrefix:\n    @pytest.fixture\n    def mock_tokenizer(self):\n        \"\"\"Create a minimal mock tokenizer for tests that don't need real tokenization.\"\"\"\n        from unittest.mock import MagicMock\n\n        tokenizer = MagicMock()\n        tokenizer.encode.return_value = [1, 2, 3]\n        return tokenizer\n\n    def test_starts_empty(self, mock_tokenizer):\n        cache = KVPrefixCache(None)\n        assert len(cache.prompts) == 0\n        assert len(cache.caches) == 0\n\n    def test_clear_empties_cache(self, mock_tokenizer):\n        cache = KVPrefixCache(None)\n        cache.prompts.append(mx.array([1, 2, 3]))\n        cache.caches.append([KVCache()])\n        cache.clear()\n        assert len(cache.prompts) == 0\n        assert len(cache.caches) == 0\n\n    def test_clear_on_empty_cache(self, mock_tokenizer):\n        cache = KVPrefixCache(None)\n        cache.clear()\n        assert len(cache.prompts) == 0\n\n\ndef _load_gpt_oss() -> tuple[Model, object]:\n    from mlx_lm.utils import load_model\n\n    from exo.worker.engines.mlx.utils_mlx import load_tokenizer_for_model_id\n\n    model_path = 
DEFAULT_GPT_OSS_CONFIG.model_path\n    model_id = ModelId(DEFAULT_GPT_OSS_MODEL_ID)\n\n    model, _ = load_model(model_path, lazy=False)\n    tokenizer = load_tokenizer_for_model_id(model_id, model_path)\n    return cast(Model, model), tokenizer\n\n\n@pytest.mark.slow\n@pytest.mark.skipif(\n    not _check_model_exists(),\n    reason=f\"GPT-OSS model not found at {DEFAULT_GPT_OSS_CONFIG.model_path}\",\n)\nclass TestKVPrefixCacheWithModel:\n    @pytest.fixture(scope=\"class\")\n    def model_and_tokenizer(self):\n        model, tokenizer = _load_gpt_oss()\n        return model, tokenizer\n\n    def test_prefill_populates_cache(self, model_and_tokenizer):\n        model, tokenizer = model_and_tokenizer\n\n        task = TextGenerationTaskParams(\n            model=DEFAULT_GPT_OSS_MODEL_ID,\n            input=[InputMessage(role=\"user\", content=\"Hello!!\")],\n            max_output_tokens=1,\n        )\n        prompt = apply_chat_template(tokenizer, task)\n        tokens = encode_prompt(tokenizer, prompt)\n        cache = make_kv_cache(model)\n\n        _, _, snapshots = prefill(\n            model,\n            tokenizer,\n            make_sampler(0.0),\n            tokens,\n            cache,\n            group=None,\n            on_prefill_progress=None,\n            distributed_prompt_progress_callback=None,\n        )\n\n        # Cache should now hold the prompt tokens minus one\n        assert cache_length(cache) == len(tokens) - 1\n        # Snapshots should be available for models with non-KV caches\n        assert len(snapshots) > 0\n\n    def test_add_and_get_exact_match(self, model_and_tokenizer):\n        model, tokenizer = model_and_tokenizer\n\n        task = TextGenerationTaskParams(\n            model=DEFAULT_GPT_OSS_MODEL_ID,\n            input=[InputMessage(role=\"user\", content=\"Test exact\")],\n            max_output_tokens=1,\n        )\n        prompt = apply_chat_template(tokenizer, task)\n        tokens = encode_prompt(tokenizer, prompt)\n        cache = make_kv_cache(model)\n\n        _, _, snapshots = prefill(\n            model,\n            tokenizer,\n            make_sampler(0.0),\n            tokens,\n            cache,\n            group=None,\n            on_prefill_progress=None,\n            distributed_prompt_progress_callback=None,\n        )\n\n        kv_prefix_cache = KVPrefixCache(None)\n        kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)\n\n        assert len(kv_prefix_cache.prompts) == 1\n        stored_length = cache_length(kv_prefix_cache.caches[0])\n        assert stored_length > 0\n\n        # Retrieve with same prompt: exact match\n        result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(\n            model, tokens\n        )\n        assert matched_index == 0\n\n        # Exact match returns last token(s) — for models with SSM/rotating caches,\n        # snapshot availability constrains how far back we can trim, so remaining\n        # may be 1 or 2 tokens depending on the model.\n        assert len(remaining_tokens) >= 1\n        assert mx.array_equal(remaining_tokens, tokens[-len(remaining_tokens) :])\n\n    def test_add_and_get_prefix_match(self, model_and_tokenizer):\n        \"\"\"get_kv_cache with a longer prompt sharing prefix should return partial match.\"\"\"\n        model, tokenizer = model_and_tokenizer\n\n        short_task = TextGenerationTaskParams(\n            model=DEFAULT_GPT_OSS_MODEL_ID,\n            input=[InputMessage(role=\"user\", content=\"Hi\")],\n            
max_output_tokens=1,\n        )\n        short_prompt = apply_chat_template(tokenizer, short_task)\n        short_tokens = encode_prompt(tokenizer, short_prompt)\n        cache = make_kv_cache(model)\n\n        _, _, snapshots = prefill(\n            model,\n            tokenizer,\n            make_sampler(0.0),\n            short_tokens,\n            cache,\n            group=None,\n            on_prefill_progress=None,\n            distributed_prompt_progress_callback=None,\n        )\n\n        kv_prefix_cache = KVPrefixCache(None)\n        kv_prefix_cache.add_kv_cache(short_tokens, cache, snapshots)\n\n        # Query with longer prompt that shares the chat template prefix\n        long_task = TextGenerationTaskParams(\n            model=DEFAULT_GPT_OSS_MODEL_ID,\n            input=[InputMessage(role=\"user\", content=\"Hi there, how are you?\")],\n            max_output_tokens=1,\n        )\n        long_prompt = apply_chat_template(tokenizer, long_task)\n        long_tokens = encode_prompt(tokenizer, long_prompt)\n\n        # The prompts share a prefix (chat template preamble + \"Hi\")\n        expected_prefix = get_prefix_length(long_tokens, short_tokens)\n        assert expected_prefix > 0, (\n            \"Prompts should share a prefix from the chat template\"\n        )\n\n        result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(\n            model, long_tokens\n        )\n        assert matched_index == 0\n\n        # remaining_tokens covers from snapshot restore position to end\n        assert len(remaining_tokens) >= len(long_tokens) - expected_prefix\n\n    def test_stored_cache_not_mutated_after_get_and_generation(\n        self, model_and_tokenizer\n    ):\n        \"\"\"Getting a cache and then mutating it (as generation does) must not corrupt stored cache.\"\"\"\n        model, tokenizer = model_and_tokenizer\n\n        task = TextGenerationTaskParams(\n            model=DEFAULT_GPT_OSS_MODEL_ID,\n            input=[InputMessage(role=\"user\", content=\"Mutation test\")],\n            max_output_tokens=1,\n        )\n        prompt = apply_chat_template(tokenizer, task)\n        tokens = encode_prompt(tokenizer, prompt)\n        cache = make_kv_cache(model)\n\n        _, _, snapshots = prefill(\n            model,\n            tokenizer,\n            make_sampler(0.0),\n            tokens,\n            cache,\n            group=None,\n            on_prefill_progress=None,\n            distributed_prompt_progress_callback=None,\n        )\n\n        kv_prefix_cache = KVPrefixCache(None)\n        kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)\n\n        stored_length = cache_length(kv_prefix_cache.caches[0])\n\n        # Get cache and mutate it (simulating what generation does)\n        result_cache, _, matched_index = kv_prefix_cache.get_kv_cache(model, tokens)\n        assert matched_index == 0\n\n        # Simulate generation: feed many additional tokens through the cache\n        head_dim = result_cache[0].keys.shape[-1]\n        num_heads = result_cache[0].keys.shape[1]\n        extra_keys = mx.random.normal((1, num_heads, 50, head_dim))\n        extra_values = mx.random.normal((1, num_heads, 50, head_dim))\n        for layer_cache in result_cache:\n            layer_cache.update_and_fetch(extra_keys, extra_values)\n        mx.eval([c.keys for c in result_cache])\n\n        # Stored cache must be unchanged\n        assert cache_length(kv_prefix_cache.caches[0]) == stored_length\n\n    def 
test_stored_cache_survives_repeated_get_mutate_cycles(\n        self, model_and_tokenizer\n    ):\n        \"\"\"Multiple get+mutate cycles (like repeated user requests) must not corrupt cache.\"\"\"\n        model, tokenizer = model_and_tokenizer\n\n        task = TextGenerationTaskParams(\n            model=DEFAULT_GPT_OSS_MODEL_ID,\n            input=[InputMessage(role=\"user\", content=\"Repeat test\")],\n            max_output_tokens=1,\n        )\n        prompt = apply_chat_template(tokenizer, task)\n        tokens = encode_prompt(tokenizer, prompt)\n        cache = make_kv_cache(model)\n\n        _, _, snapshots = prefill(\n            model,\n            tokenizer,\n            make_sampler(0.0),\n            tokens,\n            cache,\n            group=None,\n            on_prefill_progress=None,\n            distributed_prompt_progress_callback=None,\n        )\n\n        kv_prefix_cache = KVPrefixCache(None)\n        kv_prefix_cache.add_kv_cache(tokens, cache, snapshots)\n\n        stored_length = cache_length(kv_prefix_cache.caches[0])\n\n        for i in range(3):\n            result_cache, _, _ = kv_prefix_cache.get_kv_cache(model, tokens)\n\n            head_dim = result_cache[0].keys.shape[-1]\n            num_heads = result_cache[0].keys.shape[1]\n            extra = mx.random.normal((1, num_heads, 30, head_dim))\n            for layer_cache in result_cache:\n                layer_cache.update_and_fetch(extra, extra)\n            mx.eval([c.keys for c in result_cache])\n\n            assert cache_length(kv_prefix_cache.caches[0]) == stored_length, (\n                f\"Failed on loop {i}\"\n            )\n\n    def test_mlx_generate_populates_cache(self, model_and_tokenizer):\n        \"\"\"mlx_generate should save the cache after generation completes.\"\"\"\n        model, tokenizer = model_and_tokenizer\n\n        kv_prefix_cache = KVPrefixCache(None)\n        task = TextGenerationTaskParams(\n            model=DEFAULT_GPT_OSS_MODEL_ID,\n            input=[InputMessage(role=\"user\", content=\"Hello\")],\n            max_output_tokens=5,\n        )\n        prompt = apply_chat_template(tokenizer, task)\n        prompt_tokens = encode_prompt(tokenizer, prompt)\n\n        # Consume the entire generator so the cache-saving code after yield runs\n        generated_tokens = 0\n        for _response in mlx_generate(\n            model=model,\n            tokenizer=tokenizer,\n            task=task,\n            prompt=prompt,\n            kv_prefix_cache=kv_prefix_cache,\n            group=None,\n        ):\n            generated_tokens += 1\n\n        assert len(kv_prefix_cache.prompts) == 1\n        assert len(kv_prefix_cache.caches) == 1\n        # Cache should contain prompt + generated tokens\n        expected_length = len(prompt_tokens) + generated_tokens\n        assert cache_length(kv_prefix_cache.caches[0]) == expected_length\n\n    def test_mlx_generate_second_call_gets_prefix_hit(self, model_and_tokenizer):\n        \"\"\"Second mlx_generate call with same prompt should get a prefix hit from stored cache.\"\"\"\n        model, tokenizer = model_and_tokenizer\n\n        kv_prefix_cache = KVPrefixCache(None)\n        task = TextGenerationTaskParams(\n            model=DEFAULT_GPT_OSS_MODEL_ID,\n            input=[InputMessage(role=\"user\", content=\"Reuse test\")],\n            max_output_tokens=5,\n        )\n        prompt = apply_chat_template(tokenizer, task)\n        prompt_tokens = encode_prompt(tokenizer, prompt)\n\n        # First generation populates 
cache\n        for _response in mlx_generate(\n            model=model,\n            tokenizer=tokenizer,\n            task=task,\n            prompt=prompt,\n            kv_prefix_cache=kv_prefix_cache,\n            group=None,\n        ):\n            pass\n\n        assert len(kv_prefix_cache.prompts) == 1\n\n        # Second call should find a prefix match (the stored cache contains\n        # prompt + generated tokens, which shares the prompt prefix)\n        result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(\n            model, prompt_tokens\n        )\n        # The stored cache is longer than the prompt (it includes generated tokens),\n        # so this is a prefix match where our prompt is fully contained\n        assert matched_index == 0\n        # Exact match: remaining_tokens is just the last token and the one before\n        assert len(remaining_tokens) == 2\n        assert mx.array_equal(remaining_tokens, prompt_tokens[-2:])\n\n    def test_mlx_generate_long_prompt_updates_cache_in_place(self, model_and_tokenizer):\n        \"\"\"With a prompt > 1000 tokens, second generation should update the cache entry in-place.\"\"\"\n        model, tokenizer = model_and_tokenizer\n\n        kv_prefix_cache = KVPrefixCache(None)\n\n        # Build a long user message (> 1000 tokens) to exceed _MIN_PREFIX_HIT_TO_UPDATE\n        base_text = \"The quick brown fox jumps over the lazy dog. \"\n        base_tokens = tokenizer.encode(base_text)\n        repeats = (1200 // len(base_tokens)) + 2\n        long_content = base_text * repeats\n\n        task1 = TextGenerationTaskParams(\n            model=DEFAULT_GPT_OSS_MODEL_ID,\n            input=[InputMessage(role=\"user\", content=long_content)],\n            max_output_tokens=5,\n        )\n        prompt1 = apply_chat_template(tokenizer, task1)\n        prompt1_tokens = encode_prompt(tokenizer, prompt1)\n        assert len(prompt1_tokens) > 1000, (\n            \"Prompt must exceed _MIN_PREFIX_HIT_TO_UPDATE\"\n        )\n\n        # First generation populates the cache (must prefill all tokens)\n        t0 = time.perf_counter()\n        for _response in mlx_generate(\n            model=model,\n            tokenizer=tokenizer,\n            task=task1,\n            prompt=prompt1,\n            kv_prefix_cache=kv_prefix_cache,\n            group=None,\n        ):\n            pass\n        first_gen_time = time.perf_counter() - t0\n\n        assert len(kv_prefix_cache.prompts) == 1\n        first_cache_length = cache_length(kv_prefix_cache.caches[0])\n\n        # Second generation: same long prompt + extra content (simulating multi-turn)\n        task2 = TextGenerationTaskParams(\n            model=DEFAULT_GPT_OSS_MODEL_ID,\n            input=[\n                InputMessage(role=\"user\", content=long_content),\n                InputMessage(role=\"assistant\", content=\"Sure, I can help.\"),\n                InputMessage(role=\"user\", content=\"Tell me more.\"),\n            ],\n            max_output_tokens=5,\n        )\n        prompt2 = apply_chat_template(tokenizer, task2)\n        prompt2_tokens = encode_prompt(tokenizer, prompt2)\n\n        # Verify the prompts share a long prefix\n        prefix_len = get_prefix_length(prompt2_tokens, prompt1_tokens)\n        assert prefix_len > 1000, \"Prompts must share > 1000 token prefix\"\n\n        # Second generation should reuse the cached prefix (only prefill new tokens)\n        t0 = time.perf_counter()\n        for _response in mlx_generate(\n            
model=model,\n            tokenizer=tokenizer,\n            task=task2,\n            prompt=prompt2,\n            kv_prefix_cache=kv_prefix_cache,\n            group=None,\n        ):\n            pass\n        second_gen_time = time.perf_counter() - t0\n\n        # Second generation should be significantly faster thanks to the prefix\n        # cache hit; this is a timing assertion, so it may be flaky under load\n        assert second_gen_time < first_gen_time * 0.5, (\n            f\"Expected prefix cache speedup: \"\n            f\"first={first_gen_time:.2f}s, second={second_gen_time:.2f}s\"\n        )\n\n        # With prefix_hit > 1000, should update in-place (not add a second entry)\n        assert len(kv_prefix_cache.prompts) == 1\n        # Updated cache should be longer (prompt2 + generated > prompt1 + generated)\n        updated_cache_length = cache_length(kv_prefix_cache.caches[0])\n        assert updated_cache_length > first_cache_length\n\n    def test_mlx_generate_stored_cache_not_mutated(self, model_and_tokenizer):\n        \"\"\"After mlx_generate saves a cache, a second generation must not corrupt the stored copy.\"\"\"\n        model, tokenizer = model_and_tokenizer\n\n        kv_prefix_cache = KVPrefixCache(None)\n        task = TextGenerationTaskParams(\n            model=DEFAULT_GPT_OSS_MODEL_ID,\n            input=[InputMessage(role=\"user\", content=\"Immutable test\")],\n            max_output_tokens=5,\n        )\n        prompt = apply_chat_template(tokenizer, task)\n\n        # First generation populates cache\n        for _response in mlx_generate(\n            model=model,\n            tokenizer=tokenizer,\n            task=task,\n            prompt=prompt,\n            kv_prefix_cache=kv_prefix_cache,\n            group=None,\n        ):\n            pass\n\n        first_cache_length = cache_length(kv_prefix_cache.caches[0])\n\n        # Second generation gets the cache and mutates it during generation\n        for _response in mlx_generate(\n            model=model,\n            tokenizer=tokenizer,\n            task=task,\n            prompt=prompt,\n            kv_prefix_cache=kv_prefix_cache,\n            group=None,\n        ):\n            pass\n\n        # The first stored cache must not have been mutated by the second generation\n        assert cache_length(kv_prefix_cache.caches[0]) == first_cache_length\n\n    def test_evicts_lru_entry_under_memory_pressure(self, model_and_tokenizer):\n        \"\"\"Under memory pressure, adding a new cache entry evicts the least recently used one.\"\"\"\n        model, tokenizer = model_and_tokenizer\n\n        kv_prefix_cache = KVPrefixCache(None)\n\n        # Add three cache entries with different prompts\n        prompts = [\"First entry\", \"Second entry\", \"Third entry\"]\n        for i, content in enumerate(prompts):\n            task = TextGenerationTaskParams(\n                model=DEFAULT_GPT_OSS_MODEL_ID,\n                input=[InputMessage(role=\"user\", content=content)],\n                max_output_tokens=1,\n            )\n            prompt = apply_chat_template(tokenizer, task)\n            tokens = encode_prompt(tokenizer, prompt)\n            cache = make_kv_cache(model)\n            prefill(\n                model,\n                tokenizer,\n                make_sampler(0.0),\n                tokens,\n                cache,\n                group=None,\n                on_prefill_progress=None,\n                distributed_prompt_progress_callback=None,\n            )\n            kv_prefix_cache.add_kv_cache(tokens, cache)\n            # Stagger 
_last_used so LRU order is deterministic\n            kv_prefix_cache._last_used[i] = float(i)\n\n        assert len(kv_prefix_cache.prompts) == 3\n\n        # Access the third entry to make it most recently used\n        kv_prefix_cache._last_used[2] = 100.0\n        # Entry 0 (_last_used=0.0) is LRU, entry 1 (_last_used=1.0) is next\n\n        # Simulate memory pressure: return usage above _MEMORY_THRESHOLD (0.9)\n        with patch(\n            \"exo.worker.engines.mlx.cache.get_memory_used_percentage\",\n            return_value=0.95,\n        ):\n            # Trigger eviction by adding a new entry\n            task = TextGenerationTaskParams(\n                model=DEFAULT_GPT_OSS_MODEL_ID,\n                input=[InputMessage(role=\"user\", content=\"New entry\")],\n                max_output_tokens=1,\n            )\n            prompt = apply_chat_template(tokenizer, task)\n            tokens = encode_prompt(tokenizer, prompt)\n            cache = make_kv_cache(model)\n            prefill(\n                model,\n                tokenizer,\n                make_sampler(0.0),\n                tokens,\n                cache,\n                group=None,\n                on_prefill_progress=None,\n                distributed_prompt_progress_callback=None,\n            )\n            kv_prefix_cache.add_kv_cache(tokens, cache)\n\n        # LRU entries should have been evicted (entries 0, 1, 2 in order of _last_used)\n        # Since the mocked memory usage stays above the threshold after each eviction,\n        # all old entries get evicted, leaving only the newly added one\n        assert len(kv_prefix_cache.prompts) == 1\n        # The surviving entry should be the newly added one\n        assert get_prefix_length(kv_prefix_cache.prompts[0], tokens) == len(tokens)\n"
  },
  {
    "path": "src/exo/worker/tests/unittests/test_mlx/test_pipeline_prefill_callbacks.py",
    "content": "# type: ignore\n\"\"\"Test that pipeline prefill callbacks and output exactly match stream_generate.\n\nSpins up a single-device (non-pipeline) run and a distributed pipeline run,\nthen verifies that the prompt_progress_callback sequences are identical\nand that generated text matches.\n\"\"\"\n\nimport json\nimport multiprocessing as mp\nimport os\nimport tempfile\nimport traceback\nfrom typing import Any, cast\n\nimport pytest\n\nfrom exo.shared.constants import EXO_MODELS_DIR\nfrom exo.shared.models.model_cards import ModelCard, ModelTask\nfrom exo.shared.types.common import ModelId\nfrom exo.shared.types.memory import Memory\nfrom exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams\n\nMODEL_ID = \"mlx-community/gpt-oss-20b-MXFP4-Q8\"\nMODEL_PATH = EXO_MODELS_DIR / \"mlx-community--gpt-oss-20b-MXFP4-Q8\"\nTOTAL_LAYERS = 24\nMAX_TOKENS = 10\nSEED = 42\nTEMPERATURE = 0.0\n\n\ndef _model_card() -> ModelCard:\n    return ModelCard(\n        model_id=ModelId(MODEL_ID),\n        storage_size=Memory.from_gb(12),\n        n_layers=TOTAL_LAYERS,\n        hidden_size=2880,\n        supports_tensor=False,\n        tasks=[ModelTask.TextGeneration],\n    )\n\n\ndef _build_prompt(tokenizer: Any, prompt_tokens: int) -> tuple[str, Any]:\n    \"\"\"Build a prompt with the given number of user-content tokens, return (chat_prompt, task).\"\"\"\n    from exo.worker.engines.mlx.utils_mlx import apply_chat_template\n\n    base_text = \"The quick brown fox jumps over the lazy dog. \"\n    base_toks = tokenizer.encode(base_text)\n    repeats = (prompt_tokens // len(base_toks)) + 2\n    long_text = base_text * repeats\n    tokens = tokenizer.encode(long_text)[:prompt_tokens]\n    prompt_text = tokenizer.decode(tokens)\n\n    task = TextGenerationTaskParams(\n        model=MODEL_ID,\n        input=[InputMessage(role=\"user\", content=prompt_text)],\n        max_output_tokens=MAX_TOKENS,\n        temperature=TEMPERATURE,\n        seed=SEED,\n    )\n\n    prompt = apply_chat_template(tokenizer, task)\n    return prompt, task\n\n\n# ---------------------------------------------------------------------------\n# Single-device process: uses stream_generate path (no pipeline layers)\n# ---------------------------------------------------------------------------\ndef _run_single_device(\n    prompt_tokens: int,\n    result_queue: Any,\n) -> None:\n    \"\"\"Load full model without pipeline sharding, run mlx_generate, record callbacks.\"\"\"\n    try:\n        import mlx.core as mx\n        from mlx_lm.utils import load_model\n\n        from exo.shared.types.worker.shards import PipelineShardMetadata\n        from exo.worker.engines.mlx.cache import encode_prompt\n        from exo.worker.engines.mlx.generator.generate import mlx_generate\n        from exo.worker.engines.mlx.utils_mlx import (\n            build_model_path,\n            get_tokenizer,\n        )\n\n        model_path = build_model_path(ModelId(MODEL_ID))\n        model, _ = load_model(model_path, lazy=True, strict=False)\n        mx.eval(model)\n\n        # Use PipelineShardMetadata just for get_tokenizer (needs model_card), but\n        # do NOT apply pipeline sharding — the model keeps all layers unwrapped.\n        dummy_meta = PipelineShardMetadata(\n            model_card=_model_card(),\n            device_rank=0,\n            world_size=1,\n            start_layer=0,\n            end_layer=TOTAL_LAYERS,\n            n_layers=TOTAL_LAYERS,\n        )\n        tokenizer = get_tokenizer(model_path, 
dummy_meta)\n\n        prompt, task = _build_prompt(tokenizer, prompt_tokens)\n\n        callbacks: list[tuple[int, int]] = []\n\n        def on_progress(processed: int, total: int) -> None:\n            callbacks.append((processed, total))\n\n        generated_text = \"\"\n        for response in mlx_generate(\n            model=model,\n            tokenizer=tokenizer,\n            task=task,\n            prompt=prompt,\n            kv_prefix_cache=None,\n            group=None,\n            on_prefill_progress=on_progress,\n        ):\n            generated_text += response.text\n            if response.finish_reason is not None:\n                break\n\n        # Also record the token count that prefill() received (prompt_tokens[:-1])\n        all_tokens = encode_prompt(tokenizer, prompt)\n        prefill_token_count = len(all_tokens) - 1\n\n        result_queue.put(\n            (\n                True,\n                {\n                    \"callbacks\": callbacks,\n                    \"text\": generated_text,\n                    \"prefill_token_count\": prefill_token_count,\n                },\n            )\n        )\n\n    except Exception as e:\n        result_queue.put((False, f\"{e}\\n{traceback.format_exc()}\"))\n\n\n# ---------------------------------------------------------------------------\n# Pipeline device process: uses _pipeline_prefill_cache path\n# ---------------------------------------------------------------------------\ndef _run_pipeline_device(\n    rank: int,\n    world_size: int,\n    hostfile_path: str,\n    layer_splits: list[tuple[int, int]],\n    prompt_tokens: int,\n    result_queue: Any,\n) -> None:\n    \"\"\"Load model with pipeline sharding, run mlx_generate, record callbacks.\"\"\"\n    os.environ[\"MLX_HOSTFILE\"] = hostfile_path\n    os.environ[\"MLX_RANK\"] = str(rank)\n\n    try:\n        import mlx.core as mx\n\n        from exo.shared.types.worker.shards import PipelineShardMetadata\n        from exo.worker.engines.mlx.cache import encode_prompt\n        from exo.worker.engines.mlx.generator.generate import mlx_generate\n        from exo.worker.engines.mlx.utils_mlx import shard_and_load\n\n        group = mx.distributed.init(backend=\"ring\", strict=True)\n\n        start_layer, end_layer = layer_splits[rank]\n        shard_meta = PipelineShardMetadata(\n            model_card=_model_card(),\n            device_rank=rank,\n            world_size=world_size,\n            start_layer=start_layer,\n            end_layer=end_layer,\n            n_layers=TOTAL_LAYERS,\n        )\n\n        model, tokenizer = shard_and_load(\n            shard_meta, group, on_timeout=None, on_layer_loaded=None\n        )\n        model = cast(Any, model)\n\n        prompt, task = _build_prompt(tokenizer, prompt_tokens)\n\n        callbacks: list[tuple[int, int]] = []\n\n        def on_progress(processed: int, total: int) -> None:\n            callbacks.append((processed, total))\n\n        def distributed_prompt_progress_callback(_group: Any = group) -> None:\n            from exo.worker.engines.mlx.utils_mlx import mx_any\n\n            mx_any(False, _group)\n\n        generated_text = \"\"\n        for response in mlx_generate(\n            model=model,\n            tokenizer=tokenizer,\n            task=task,\n            prompt=prompt,\n            kv_prefix_cache=None,\n            group=group,\n            on_prefill_progress=on_progress,\n            distributed_prompt_progress_callback=distributed_prompt_progress_callback,\n        ):\n            
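# Drain the generator: accumulate text and stop at the first finish_reason so\n            # each rank records its full prefill-callback sequence for comparison.\n            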
generated_text += response.text\n            if response.finish_reason is not None:\n                break\n\n        all_tokens = encode_prompt(tokenizer, prompt)\n        prefill_token_count = len(all_tokens) - 1\n\n        result_queue.put(\n            (\n                rank,\n                True,\n                {\n                    \"callbacks\": callbacks,\n                    \"text\": generated_text,\n                    \"prefill_token_count\": prefill_token_count,\n                },\n            )\n        )\n\n    except Exception as e:\n        result_queue.put((rank, False, f\"{e}\\n{traceback.format_exc()}\"))\n\n\n# ---------------------------------------------------------------------------\n# Test helpers\n# ---------------------------------------------------------------------------\ndef _create_hostfile(world_size: int, base_port: int) -> str:\n    hosts = [f\"127.0.0.1:{base_port + i}\" for i in range(world_size)]\n    with tempfile.NamedTemporaryFile(mode=\"w\", suffix=\".json\", delete=False) as f:\n        json.dump(hosts, f)\n        return f.name\n\n\ndef _run_single_device_test(prompt_tokens: int, timeout: int = 120) -> dict[str, Any]:\n    \"\"\"Run single-device (stream_generate) prefill and return results.\"\"\"\n    ctx = mp.get_context(\"spawn\")\n    result_queue: Any = ctx.Queue()\n\n    p = ctx.Process(target=_run_single_device, args=(prompt_tokens, result_queue))\n    p.start()\n    p.join(timeout=timeout)\n\n    if p.is_alive():\n        p.terminate()\n        p.join(timeout=5)\n        pytest.fail(\"Single-device process timed out\")\n\n    assert not result_queue.empty(), \"Single-device process produced no result\"\n    success, data = result_queue.get()\n    assert success, f\"Single-device process failed:\\n{data}\"\n    return data\n\n\ndef _run_pipeline_test(\n    layer_splits: list[tuple[int, int]],\n    prompt_tokens: int,\n    base_port: int,\n    timeout: int = 120,\n) -> dict[int, dict[str, Any]]:\n    \"\"\"Run pipeline prefill across ranks and return per-rank results.\"\"\"\n    world_size = len(layer_splits)\n    hostfile_path = _create_hostfile(world_size, base_port)\n    ctx = mp.get_context(\"spawn\")\n    result_queue: Any = ctx.Queue()\n\n    try:\n        processes: list[Any] = []\n        for rank in range(world_size):\n            p = ctx.Process(\n                target=_run_pipeline_device,\n                args=(\n                    rank,\n                    world_size,\n                    hostfile_path,\n                    layer_splits,\n                    prompt_tokens,\n                    result_queue,\n                ),\n            )\n            p.start()\n            processes.append(p)\n\n        for p in processes:\n            p.join(timeout=timeout)\n\n        timed_out = any(p.is_alive() for p in processes)\n        for p in processes:\n            if p.is_alive():\n                p.terminate()\n                p.join(timeout=5)\n\n        assert not timed_out, \"Pipeline processes timed out\"\n\n        results: dict[int, dict[str, Any]] = {}\n        while not result_queue.empty():\n            rank, success, data = result_queue.get()\n            assert success, f\"Pipeline rank {rank} failed:\\n{data}\"\n            results[rank] = data\n\n        assert len(results) == world_size, (\n            f\"Expected {world_size} results, got {len(results)}: missing ranks {set(range(world_size)) - results.keys()}\"\n        )\n        return results\n\n    finally:\n        os.unlink(hostfile_path)\n\n\n# 
---------------------------------------------------------------------------\n# Tests\n# ---------------------------------------------------------------------------\npytestmark = [\n    pytest.mark.slow,\n    pytest.mark.skipif(\n        not MODEL_PATH.exists(),\n        reason=f\"GPT-OSS model not found at {MODEL_PATH}\",\n    ),\n]\n\nLAYER_SPLITS_4WAY: list[tuple[int, int]] = [(0, 6), (6, 12), (12, 18), (18, 24)]\nLAYER_SPLITS_2WAY: list[tuple[int, int]] = [(0, 12), (12, 24)]\n\n\nclass TestPipelineNoDeadlock:\n    \"\"\"Pipeline prefill must not deadlock at any rank count or prompt length.\"\"\"\n\n    @pytest.mark.parametrize(\n        \"layer_splits,prompt_tokens\",\n        [\n            (LAYER_SPLITS_2WAY, 128),\n            (LAYER_SPLITS_2WAY, 4096),\n            (LAYER_SPLITS_2WAY, 8192),\n            (LAYER_SPLITS_2WAY, 16384),\n            (LAYER_SPLITS_4WAY, 128),\n            (LAYER_SPLITS_4WAY, 4096),\n            (LAYER_SPLITS_4WAY, 8192),\n            (LAYER_SPLITS_4WAY, 16384),\n        ],\n        ids=[\n            \"2rank_128tok\",\n            \"2rank_4096tok\",\n            \"2rank_8192tok\",\n            \"2rank_16384tok\",\n            \"4rank_128tok\",\n            \"4rank_4096tok\",\n            \"4rank_8192tok\",\n            \"4rank_16384tok\",\n        ],\n    )\n    def test_no_deadlock(\n        self,\n        layer_splits: list[tuple[int, int]],\n        prompt_tokens: int,\n    ) -> None:\n        \"\"\"Pipeline must complete without deadlock at various prompt lengths.\"\"\"\n        pipeline_results = _run_pipeline_test(\n            layer_splits=layer_splits,\n            prompt_tokens=prompt_tokens,\n            base_port=29650,\n            timeout=60,\n        )\n        # If we get here, no deadlock. Verify all ranks produced output.\n        for rank, pipe_data in sorted(pipeline_results.items()):\n            assert pipe_data[\"text\"], f\"Rank {rank} produced no output text\"\n\n\nclass TestPipelinePrefillCallbacks:\n    \"\"\"Verify that pipeline prefill callbacks exactly match stream_generate callbacks.\"\"\"\n\n    @pytest.mark.parametrize(\n        \"prompt_tokens\",\n        [50, 500, 5000],\n        ids=[\"short_50\", \"medium_500\", \"long_5000\"],\n    )\n    def test_callbacks_match(self, prompt_tokens: int) -> None:\n        \"\"\"All pipeline ranks must produce identical callback sequences.\"\"\"\n        # Run 4-rank pipeline\n        pipeline_results = _run_pipeline_test(\n            layer_splits=LAYER_SPLITS_4WAY,\n            prompt_tokens=prompt_tokens,\n            base_port=29700,\n            timeout=180,\n        )\n\n        # All ranks must agree on prefill token count and callback sequence\n        rank0_data = pipeline_results[0]\n        rank0_callbacks = rank0_data[\"callbacks\"]\n        prefill_count = rank0_data[\"prefill_token_count\"]\n\n        for rank, pipe_data in sorted(pipeline_results.items()):\n            pipe_callbacks = pipe_data[\"callbacks\"]\n\n            assert pipe_data[\"prefill_token_count\"] == prefill_count, (\n                f\"Rank {rank} prefill token count mismatch: \"\n                f\"{pipe_data['prefill_token_count']} vs {prefill_count}\"\n            )\n\n            assert pipe_callbacks == rank0_callbacks, (\n                f\"Rank {rank} callback mismatch for {prompt_tokens} prompt tokens \"\n                f\"(prefill M={prefill_count}):\\n\"\n                f\"  pipeline R0 ({len(rank0_callbacks)} callbacks): {rank0_callbacks}\\n\"\n                f\"  pipeline R{rank} 
({len(pipe_callbacks)} callbacks): {pipe_callbacks}\"\n            )\n\n        # Structural checks: starts with (0, M), ends with (M, M), monotonically increasing\n        assert rank0_callbacks[0] == (0, prefill_count), (\n            f\"First callback should be (0, {prefill_count}), got {rank0_callbacks[0]}\"\n        )\n        assert rank0_callbacks[-1] == (prefill_count, prefill_count), (\n            f\"Last callback should be ({prefill_count}, {prefill_count}), got {rank0_callbacks[-1]}\"\n        )\n        for i in range(1, len(rank0_callbacks)):\n            assert rank0_callbacks[i][0] >= rank0_callbacks[i - 1][0], (\n                f\"Callbacks not monotonically increasing at index {i}: {rank0_callbacks}\"\n            )\n\n    @pytest.mark.parametrize(\n        \"prompt_tokens\",\n        [50, 500],\n        ids=[\"short_50\", \"medium_500\"],\n    )\n    def test_output_matches(self, prompt_tokens: int) -> None:\n        \"\"\"Pipeline-generated text must match single-device output.\"\"\"\n        single = _run_single_device_test(prompt_tokens, timeout=180)\n\n        pipeline_results = _run_pipeline_test(\n            layer_splits=LAYER_SPLITS_4WAY,\n            prompt_tokens=prompt_tokens,\n            base_port=29800,\n            timeout=180,\n        )\n\n        single_text = single[\"text\"]\n\n        # The last rank produces the final logits, so compare its output against\n        # the single-device run.\n        last_rank = max(pipeline_results.keys())\n        pipe_text = pipeline_results[last_rank][\"text\"]\n\n        # With deterministic sampling (temperature=0.0) the outputs must match\n        # exactly; on divergence, fail and report the first differing character.\n        if single_text != pipe_text:\n            # Find first divergence point\n            min_len = min(len(single_text), len(pipe_text))\n            diverge_idx = next(\n                (i for i in range(min_len) if single_text[i] != pipe_text[i]),\n                min_len,\n            )\n            pytest.fail(\n                f\"Output text diverged at character {diverge_idx} for {prompt_tokens} prompt tokens:\\n\"\n                f\"  single-device: {single_text!r}\\n\"\n                f\"  pipeline R{last_rank}: {pipe_text!r}\"\n            )\n\n\nclass TestPipelineCallbacksStructure:\n    \"\"\"Verify structural properties of callbacks independent of model output.\"\"\"\n\n    def test_callback_structure_matches_generate_step(self) -> None:\n        \"\"\"Verify callbacks follow generate_step's pattern: (0,M), chunks up to M-1, (M,M).\"\"\"\n        prompt_tokens = 200\n        pipeline_results = _run_pipeline_test(\n            layer_splits=LAYER_SPLITS_4WAY,\n            prompt_tokens=prompt_tokens,\n            base_port=29900,\n            timeout=180,\n        )\n\n        for rank, pipe_data in sorted(pipeline_results.items()):\n            callbacks = pipe_data[\"callbacks\"]\n            m = pipe_data[\"prefill_token_count\"]\n            assert m > 0, f\"Rank {rank}: prefill token count is 0\"\n\n            assert callbacks[0] == (0, m), (\n                f\"Rank {rank}: first callback should be (0, {m}), got {callbacks[0]}\"\n            )\n\n            assert callbacks[-1] == (m, m), (\n                f\"Rank {rank}: last callback should be ({m}, {m}), got {callbacks[-1]}\"\n            )\n\n            if len(callbacks) > 2:\n                second_to_last = callbacks[-2]\n                assert second_to_last[0] < m, (\n            
        f\"Rank {rank}: second-to-last callback should report < {m}, \"\n                    f\"got {second_to_last}\"\n                )\n\n            # All callbacks must have total == M\n            for i, (_, total) in enumerate(callbacks):\n                assert total == m, (\n                    f\"Rank {rank}: callback {i} has total={total}, expected {m}\"\n                )\n\n            # processed values must be non-decreasing\n            processed_vals = [p for p, _ in callbacks]\n            for i in range(1, len(processed_vals)):\n                assert processed_vals[i] >= processed_vals[i - 1], (\n                    f\"Rank {rank}: callbacks not non-decreasing at index {i}: \"\n                    f\"{processed_vals}\"\n                )\n\n            # No duplicate consecutive callbacks (pipeline dummies must not emit callbacks)\n            for i in range(1, len(callbacks)):\n                assert callbacks[i] != callbacks[i - 1], (\n                    f\"Rank {rank}: duplicate consecutive callback at index {i}: \"\n                    f\"{callbacks[i]} (this suggests dummy iterations are emitting callbacks)\"\n                )\n"
  },
  {
    "path": "src/exo/worker/tests/unittests/test_mlx/test_prefix_cache_architectures.py",
    "content": "import copy\nimport gc\nimport importlib\nimport json\nimport shutil\nimport tempfile\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import Any, cast\n\nimport mlx.core as mx\nimport mlx.nn as nn\nimport pytest\nfrom mlx.utils import tree_flatten, tree_unflatten\nfrom mlx_lm.tokenizer_utils import TokenizerWrapper\n\nfrom exo.shared.types.common import ModelId\nfrom exo.shared.types.mlx import Model\nfrom exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams\nfrom exo.worker.engines.mlx.cache import KVPrefixCache\nfrom exo.worker.engines.mlx.generator.generate import mlx_generate\nfrom exo.worker.engines.mlx.utils_mlx import (\n    apply_chat_template,\n    load_tokenizer_for_model_id,\n)\n\nHF_CACHE = Path.home() / \".cache\" / \"huggingface\" / \"hub\"\n\n# ── Config reduction ──────────────────────────────────────────────────────── #\n\n_REDUCE = {\n    \"num_hidden_layers\": 4,\n    \"hidden_size\": 256,\n    \"num_attention_heads\": 4,\n    \"num_key_value_heads\": 4,\n    \"intermediate_size\": 512,\n    \"moe_intermediate_size\": 128,\n    \"num_experts\": 4,\n    \"num_experts_per_tok\": 2,\n    \"n_routed_experts\": 4,\n    \"num_local_experts\": 4,\n    \"num_nextn_predict_layers\": 0,\n    \"first_k_dense_replace\": 0,\n    \"linear_num_key_heads\": 2,\n    \"linear_num_value_heads\": 2,\n    \"num_attention_groups\": 4,\n}\n\n\ndef _reduce_dict(cfg: dict[str, Any]) -> dict[str, Any]:\n    result = dict(cfg)\n    for key, val in _REDUCE.items():\n        if key in result:\n            result[key] = val\n    return result\n\n\ndef _reduce_config(cfg: dict[str, Any]) -> dict[str, Any]:\n    result = _reduce_dict(cfg)\n    n_layers = cast(int, result.get(\"num_hidden_layers\", 4))\n\n    if \"text_config\" in result and isinstance(result[\"text_config\"], dict):\n        result[\"text_config\"] = _reduce_dict(\n            cast(dict[str, Any], result[\"text_config\"])\n        )\n        tc: dict[str, Any] = result[\"text_config\"]\n        if \"num_nextn_predict_layers\" in tc:\n            tc[\"num_nextn_predict_layers\"] = 0\n\n    if \"layer_types\" in result and isinstance(result[\"layer_types\"], list):\n        result[\"layer_types\"] = result[\"layer_types\"][:n_layers]\n\n    if \"attention_other_setting\" in result and isinstance(\n        result[\"attention_other_setting\"], dict\n    ):\n        aos: dict[str, Any] = dict(\n            cast(dict[str, Any], result[\"attention_other_setting\"])\n        )\n        if \"num_attention_heads\" in aos:\n            aos[\"num_attention_heads\"] = result.get(\"num_attention_heads\", 4)\n        if \"num_attention_groups\" in aos:\n            aos[\"num_attention_groups\"] = result.get(\n                \"num_attention_groups\", cast(int, aos[\"num_attention_groups\"])\n            )\n        result[\"attention_other_setting\"] = aos\n\n    if \"moe_layers_enum\" in result and isinstance(result[\"moe_layers_enum\"], str):\n        indices = [int(x) for x in result[\"moe_layers_enum\"].split(\",\") if x.strip()]\n        valid = [i for i in indices if i < n_layers]\n        result[\"moe_layers_enum\"] = \",\".join(str(i) for i in valid) if valid else \"\"\n\n    return result\n\n\n# ── Helpers ───────────────────────────────────────────────────────────────── #\n\n\ndef _find_snapshot(hub_name: str) -> Path | None:\n    model_dir = HF_CACHE / f\"models--mlx-community--{hub_name}\"\n    snaps = model_dir / \"snapshots\"\n    if not snaps.exists():\n        return 
None\n    children = sorted(snaps.iterdir())\n    return children[0] if children else None\n\n\ndef _copy_tokenizer(src: Path, dst: Path) -> None:\n    for f in src.iterdir():\n        name = f.name\n        if (\n            \"tokeniz\" in name.lower()\n            or \"tiktoken\" in name.lower()\n            or name.startswith(\"vocab\")\n            or name.endswith(\".jinja\")\n            or \"tool_declaration\" in name\n        ) and f.is_file():\n            shutil.copy2(f, dst / name)\n\n\ndef _build_model(module_name: str, cfg: dict[str, Any]) -> Model:\n    mod = importlib.import_module(f\"mlx_lm.models.{module_name}\")\n    args = mod.ModelArgs.from_dict(cfg)  # pyright: ignore[reportAny]\n    model: nn.Module = mod.Model(args)  # pyright: ignore[reportAny]\n    flat = cast(list[tuple[str, mx.array]], tree_flatten(model.parameters()))\n    random_weights = [\n        (k, mx.random.normal(shape=v.shape, dtype=mx.float16)) for k, v in flat\n    ]\n    model.update(cast(dict[str, Any], tree_unflatten(random_weights)))\n    mx.eval(model.parameters())\n    return cast(Model, model)\n\n\ndef _collect_tokens(\n    model: Model,\n    tokenizer: TokenizerWrapper,\n    task: TextGenerationTaskParams,\n    prompt: str,\n    kv_prefix_cache: KVPrefixCache | None,\n) -> list[int]:\n    tokens: list[int] = []\n    for resp in mlx_generate(\n        model=model,\n        tokenizer=tokenizer,\n        task=task,\n        prompt=prompt,\n        kv_prefix_cache=kv_prefix_cache,\n        group=None,\n    ):\n        tokens.append(resp.token)\n        if resp.finish_reason is not None:\n            break\n    return tokens\n\n\n# ── Architecture definitions ──────────────────────────────────────────────── #\n\n\n@dataclass(frozen=True)\nclass ArchSpec:\n    name: str\n    hub_name: str\n    module: str\n    tokenizer_hub: str | None = None  # fallback for models without bundled tokenizer\n\n\nARCHITECTURES: list[ArchSpec] = [\n    ArchSpec(\"llama\", \"Llama-3.2-1B-Instruct-4bit\", \"llama\"),\n    ArchSpec(\"glm_moe_dsa\", \"GLM-5-MXFP4-Q8\", \"glm_moe_dsa\"),\n    ArchSpec(\n        \"glm4_moe\", \"GLM-4.5-Air-8bit\", \"glm4_moe\", tokenizer_hub=\"GLM-4.7-8bit-gs32\"\n    ),\n    ArchSpec(\n        \"glm4_moe_lite\",\n        \"GLM-4.7-Flash-8bit\",\n        \"glm4_moe_lite\",\n        tokenizer_hub=\"GLM-4.7-8bit-gs32\",\n    ),\n    ArchSpec(\"glm4_moe_47\", \"GLM-4.7-8bit-gs32\", \"glm4_moe\"),\n    ArchSpec(\"qwen3\", \"Qwen3-4B-Instruct-2507-4bit\", \"qwen3\"),\n    ArchSpec(\"qwen3_moe\", \"Qwen3-30B-A3B-4bit\", \"qwen3_moe\"),\n    ArchSpec(\"qwen3_next\", \"Qwen3-Next-80B-A3B-Thinking-4bit\", \"qwen3_next\"),\n    ArchSpec(\"minimax\", \"MiniMax-M2.1-3bit\", \"minimax\"),\n    ArchSpec(\"gpt_oss\", \"gpt-oss-20b-MXFP4-Q8\", \"gpt_oss\"),\n    ArchSpec(\"step3p5\", \"Step-3.5-Flash-4bit\", \"step3p5\"),\n    ArchSpec(\"kimi_k25\", \"Kimi-K2.5\", \"kimi_k25\"),\n    ArchSpec(\"qwen3_5\", \"Qwen3.5-2B-MLX-8bit\", \"qwen3_5\"),\n    ArchSpec(\"qwen3_5_moe\", \"Qwen3.5-35B-A3B-4bit\", \"qwen3_5_moe\"),\n]\n\n\ndef _arch_available(spec: ArchSpec) -> bool:\n    snap = _find_snapshot(spec.hub_name)\n    if snap is None:\n        return False\n    if spec.tokenizer_hub is not None:\n        return _find_snapshot(spec.tokenizer_hub) is not None\n    return True\n\n\ndef _make_task() -> TextGenerationTaskParams:\n    return TextGenerationTaskParams(\n        model=ModelId(\"test\"),\n        input=[\n            InputMessage(\n                role=\"user\",\n                content=\"Use the 
calculator to compute 1847 * 263 + 5921\",\n            )\n        ],\n        max_output_tokens=20,\n        temperature=0.0,\n        tools=[\n            {\n                \"type\": \"function\",\n                \"function\": {\n                    \"name\": \"calculate\",\n                    \"description\": \"Evaluate a mathematical expression\",\n                    \"parameters\": {\n                        \"type\": \"object\",\n                        \"properties\": {\"expression\": {\"type\": \"string\"}},\n                        \"required\": [\"expression\"],\n                    },\n                },\n            }\n        ],\n    )\n\n\n# ── Test class ────────────────────────────────────────────────────────────── #\n\n\n@pytest.mark.slow\nclass TestPrefixCacheArchitectures:\n    \"\"\"Verify prefix cache produces identical output to fresh generation for every architecture.\"\"\"\n\n    @pytest.fixture(autouse=True)\n    def _cleanup(self):\n        yield\n        mx.clear_cache()\n        gc.collect()\n\n    @pytest.mark.parametrize(\n        \"spec\",\n        ARCHITECTURES,\n        ids=[a.name for a in ARCHITECTURES],\n    )\n    def test_prefix_cache_exact_hit(self, spec: ArchSpec) -> None:\n        if not _arch_available(spec):\n            pytest.skip(f\"Model {spec.hub_name} not cached locally\")\n\n        snapshot = _find_snapshot(spec.hub_name)\n        assert snapshot is not None\n\n        tmpdir = Path(tempfile.mkdtemp(prefix=f\"exo_test_{spec.name}_\"))\n        try:\n            # Build reduced config\n            with open(snapshot / \"config.json\") as f:\n                cfg = cast(dict[str, Any], json.load(f))\n            reduced = _reduce_config(copy.deepcopy(cfg))\n            (tmpdir / \"config.json\").write_text(json.dumps(reduced))\n\n            # Copy tokenizer\n            tok_src = snapshot\n            if spec.tokenizer_hub is not None:\n                alt = _find_snapshot(spec.tokenizer_hub)\n                if alt is not None:\n                    tok_src = alt\n            _copy_tokenizer(tok_src, tmpdir)\n\n            # Load tokenizer and model\n            model_id = ModelId(f\"mlx-community/{spec.hub_name}\")\n            tokenizer = load_tokenizer_for_model_id(model_id, tmpdir)\n            mx.random.seed(0)\n            model = _build_model(spec.module, reduced)\n\n            task = _make_task()\n            prompt = apply_chat_template(tokenizer=tokenizer, task_params=task)\n\n            # Run 1: fresh\n            mx.random.seed(42)\n            fresh = _collect_tokens(model, tokenizer, task, prompt, None)\n            assert len(fresh) > 0, \"Fresh generation produced no tokens\"\n\n            # Run 2: populate cache\n            kv = KVPrefixCache(None)\n            mx.random.seed(42)\n            populate = _collect_tokens(model, tokenizer, task, prompt, kv)\n\n            # Run 3: exact cache hit\n            mx.random.seed(42)\n            cached = _collect_tokens(model, tokenizer, task, prompt, kv)\n\n            assert fresh == populate, (\n                f\"Fresh vs populate mismatch: {fresh[:5]} vs {populate[:5]}\"\n            )\n            assert fresh == cached, (\n                f\"Fresh vs cached mismatch: {fresh[:5]} vs {cached[:5]}\"\n            )\n        finally:\n            shutil.rmtree(tmpdir, ignore_errors=True)\n"
  },
  {
    "path": "src/exo/worker/tests/unittests/test_mlx/test_tokenizers.py",
    "content": "\"\"\"\nUnit tests for tokenizer loading and functionality across all supported models.\n\nThis test downloads only tokenizer-related files (not full model weights) to verify\nthat tokenizers can be loaded and used correctly for encoding/decoding.\n\"\"\"\n\nimport asyncio\nimport contextlib\nfrom pathlib import Path\n\nimport pytest\n\nfrom exo.download.download_utils import (\n    download_file_with_retry,\n    ensure_models_dir,\n    fetch_file_list_with_cache,\n)\nfrom exo.shared.models.model_cards import ModelCard, ModelId, get_model_cards\nfrom exo.worker.engines.mlx.utils_mlx import (\n    get_eos_token_ids_for_model,\n    load_tokenizer_for_model_id,\n)\n\n# Files needed for tokenizer functionality\nTOKENIZER_FILE_PATTERNS = [\n    \"tokenizer.json\",\n    \"tokenizer_config.json\",\n    \"special_tokens_map.json\",\n    \"vocab.json\",\n    \"vocab.txt\",\n    \"merges.txt\",\n    \"tiktoken.model\",\n    \"added_tokens.json\",\n    \"tokenizer.model\",\n    \"tokenization_*.py\",  # Custom tokenizer implementations\n    \"tool_declaration_ts.py\",  # Dependency of tokenization_kimi.py\n]\n\n\ndef is_tokenizer_file(filename: str) -> bool:\n    \"\"\"Check if a file is needed for tokenizer functionality.\"\"\"\n    for pattern in TOKENIZER_FILE_PATTERNS:\n        if \"*\" in pattern:\n            prefix = pattern.split(\"*\")[0]\n            suffix = pattern.split(\"*\")[1]\n            if filename.startswith(prefix) and filename.endswith(suffix):\n                return True\n        elif filename == pattern:\n            return True\n    return False\n\n\nasync def download_tokenizer_files(model_id: ModelId) -> Path:\n    \"\"\"Download only the tokenizer-related files for a model.\"\"\"\n    target_dir = await ensure_models_dir() / model_id.normalize()\n    target_dir.mkdir(parents=True, exist_ok=True)\n\n    file_list = await fetch_file_list_with_cache(model_id, \"main\", recursive=True)\n\n    tokenizer_files = [f for f in file_list if is_tokenizer_file(f.path)]\n\n    if not tokenizer_files:\n        pytest.skip(f\"No tokenizer files found for {model_id}\")\n\n    for file_entry in tokenizer_files:\n        with contextlib.suppress(FileNotFoundError):\n            await download_file_with_retry(\n                model_id, \"main\", file_entry.path, target_dir\n            )\n\n    return target_dir\n\n\n# Get a sample of models to test (one per family to keep tests fast)\ndef get_test_models() -> list[ModelCard]:\n    \"\"\"Get a representative sample of models to test.\"\"\"\n    # Pick one model from each family to test\n    families: dict[str, ModelCard] = {}\n    for card in asyncio.run(get_model_cards()):\n        # Extract family name (e.g., \"llama-3.1\" from \"llama-3.1-8b\")\n        parts = card.model_id.short().split(\"-\")\n        family = \"-\".join(parts[:2]) if len(parts) >= 2 else parts[0]\n\n        if family not in families:\n            families[family] = card\n\n    return list(families.values())\n\n\nTEST_MODELS: list[ModelCard] = get_test_models()\n\npytestmark = pytest.mark.slow\n\n\n@pytest.fixture(scope=\"module\")\ndef event_loop():\n    \"\"\"Create event loop for async tests.\"\"\"\n    loop = asyncio.new_event_loop()\n    yield loop\n    loop.close()\n\n\n@pytest.mark.parametrize(\n    \"model_card\",\n    TEST_MODELS,\n)\n@pytest.mark.asyncio\nasync def test_tokenizer_encode_decode(model_card: ModelCard) -> None:\n    \"\"\"Test that tokenizer can encode and decode text correctly.\"\"\"\n    model_id = model_card.model_id\n\n    
# Download tokenizer files\n    model_path = await download_tokenizer_files(model_id)\n\n    # Verify required files exist\n    has_tokenizer = (\n        (model_path / \"tokenizer.json\").exists()\n        or (model_path / \"tokenizer_config.json\").exists()\n        or (model_path / \"tiktoken.model\").exists()\n        or (model_path / \"tokenizer.model\").exists()\n    )\n    if not has_tokenizer:\n        pytest.skip(f\"Required tokenizer files not found for {model_id}\")\n\n    # Load tokenizer\n    tokenizer = load_tokenizer_for_model_id(model_id, model_path)\n\n    # Test basic encoding\n    test_text = \"Hello, world!\"\n    encoded = tokenizer.encode(test_text)\n    assert isinstance(encoded, list), f\"encode() should return a list for {model_id}\"\n    assert len(encoded) > 0, f\"encode() should return non-empty list for {model_id}\"\n    assert all(isinstance(t, int) for t in encoded), (\n        f\"All tokens should be integers for {model_id}\"\n    )\n\n    # Test decoding\n    decoded = tokenizer.decode(encoded)\n    assert isinstance(decoded, str), f\"decode() should return a string for {model_id}\"\n    assert test_text in decoded or decoded.strip() == test_text.strip(), (\n        f\"decode(encode(x)) should preserve text for {model_id}: got {decoded!r}\"\n    )\n\n    # Test with longer text\n    long_text = \"The quick brown fox jumps over the lazy dog. \" * 10\n    long_encoded = tokenizer.encode(long_text)\n    assert len(long_encoded) > len(encoded), (\n        f\"Longer text should produce more tokens for {model_id}\"\n    )\n\n    # Test empty string\n    empty_encoded = tokenizer.encode(\"\")\n    assert isinstance(empty_encoded, list), (\n        f\"encode('') should return a list for {model_id}\"\n    )\n\n    # Test special characters\n    special_text = 'Hello!\\n\\tWorld? 
<test> & \"quotes\"'\n    special_encoded = tokenizer.encode(special_text)\n    assert len(special_encoded) > 0, f\"Special chars should encode for {model_id}\"\n\n    # Test unicode\n    unicode_text = \"Hello 世界 🌍\"\n    unicode_encoded = tokenizer.encode(unicode_text)\n    assert len(unicode_encoded) > 0, f\"Unicode should encode for {model_id}\"\n\n\n@pytest.mark.parametrize(\n    \"model_card\",\n    TEST_MODELS,\n)\n@pytest.mark.asyncio\nasync def test_tokenizer_has_required_attributes(model_card: ModelCard) -> None:\n    \"\"\"Test that tokenizer has required attributes for inference.\"\"\"\n    model_id = model_card.model_id\n\n    model_path = await download_tokenizer_files(model_id)\n\n    has_tokenizer = (\n        (model_path / \"tokenizer.json\").exists()\n        or (model_path / \"tokenizer_config.json\").exists()\n        or (model_path / \"tiktoken.model\").exists()\n        or (model_path / \"tokenizer.model\").exists()\n    )\n    if not has_tokenizer:\n        pytest.skip(f\"Required tokenizer files not found for {model_id}\")\n\n    tokenizer = load_tokenizer_for_model_id(model_id, model_path)\n    eos_token_ids = get_eos_token_ids_for_model(model_id)\n\n    # Check for vocabulary size\n    empty_vocab: dict[str, int] = {}\n    vocab_size: int = getattr(tokenizer, \"vocab_size\", None) or len(\n        getattr(tokenizer, \"get_vocab\", lambda: empty_vocab)()\n    )\n    assert vocab_size > 0, f\"Tokenizer should have vocab_size > 0 for {model_id}\"\n\n    # Check for EOS token (either from tokenizer or explicitly provided)\n    has_eos = (\n        eos_token_ids is not None\n        or getattr(tokenizer, \"eos_token_id\", None) is not None\n        or getattr(tokenizer, \"eos_token\", None) is not None\n    )\n    assert has_eos, f\"Tokenizer should have EOS token for {model_id}\"\n\n\n@pytest.mark.parametrize(\n    \"model_card\",\n    TEST_MODELS,\n)\n@pytest.mark.asyncio\nasync def test_tokenizer_special_tokens(model_card: ModelCard) -> None:\n    \"\"\"Test that tokenizer can encode text containing special tokens.\n\n    This is critical because the actual inference path uses prompts with\n    special tokens from chat templates. 
If special tokens aren't handled\n    correctly, encoding will fail.\n    \"\"\"\n    model_id = model_card.model_id\n\n    model_path = await download_tokenizer_files(model_id)\n\n    has_tokenizer = (\n        (model_path / \"tokenizer.json\").exists()\n        or (model_path / \"tokenizer_config.json\").exists()\n        or (model_path / \"tiktoken.model\").exists()\n        or (model_path / \"tokenizer.model\").exists()\n    )\n    assert has_tokenizer, f\"Required tokenizer files not found for {model_id}\"\n\n    tokenizer = load_tokenizer_for_model_id(model_id, model_path)\n\n    # Get special tokens from the tokenizer\n    special_tokens: list[str] = []\n\n    # Try to get special tokens from various sources\n    if hasattr(tokenizer, \"all_special_tokens\"):\n        special_tokens.extend(tokenizer.all_special_tokens)\n    elif hasattr(tokenizer, \"_tokenizer\") and hasattr(\n        tokenizer._tokenizer,\n        \"all_special_tokens\",\n    ):\n        special_tokens.extend(tokenizer._tokenizer.all_special_tokens)\n\n    # Also check for common special token attributes\n    for attr in [\n        \"bos_token\",\n        \"eos_token\",\n        \"pad_token\",\n        \"unk_token\",\n        \"sep_token\",\n        \"cls_token\",\n    ]:\n        token = getattr(tokenizer, attr, None)\n        if token is None and hasattr(tokenizer, \"_tokenizer\"):\n            token = getattr(tokenizer._tokenizer, attr, None)\n        if token and isinstance(token, str) and token not in special_tokens:\n            special_tokens.append(token)\n\n    # If we found special tokens, test encoding text that contains them\n    if special_tokens:\n        # Create text with special tokens interspersed\n        test_with_special = f\"{special_tokens[0]}Hello world\"\n        if len(special_tokens) > 1:\n            test_with_special += f\"{special_tokens[1]}\"\n\n        encoded = tokenizer.encode(test_with_special)\n        assert isinstance(encoded, list), (\n            f\"encode() with special tokens should return list for {model_id}\"\n        )\n        assert len(encoded) > 0, (\n            f\"encode() with special tokens should return non-empty list for {model_id}\"\n        )\n        assert all(isinstance(t, int) for t in encoded), (\n            f\"All tokens should be integers for {model_id}\"\n        )\n\n        # Verify we can decode\n        decoded = tokenizer.decode(encoded)\n        assert isinstance(decoded, str), f\"decode() should return string for {model_id}\"\n\n    # Test with angle-bracket tokens (common format for special tokens)\n    # These should not raise errors even if they're not actual special tokens\n    angle_bracket_text = \"<|test|>Hello<|end|>\"\n    encoded = tokenizer.encode(angle_bracket_text)\n    assert isinstance(encoded, list), (\n        f\"encode() with angle brackets should return list for {model_id}\"\n    )\n    assert len(encoded) > 0, (\n        f\"encode() with angle brackets should be non-empty for {model_id}\"\n    )\n\n\n# Specifically test Kimi tokenizer since it has special handling\n@pytest.mark.asyncio\nasync def test_kimi_tokenizer_specifically():\n    \"\"\"Test Kimi tokenizer with its specific patches and quirks.\"\"\"\n    kimi_models = [\n        card for card in await get_model_cards() if \"kimi\" in card.model_id.lower()\n    ]\n\n    if not kimi_models:\n        pytest.skip(\"No Kimi models found in MODEL_CARDS\")\n\n    model_card = kimi_models[0]\n    model_id = model_card.model_id\n\n    model_path = await 
download_tokenizer_files(model_id)\n\n    # Ensure the custom tokenizer file exists\n    if not (model_path / \"tokenization_kimi.py\").exists():\n        pytest.skip(\"tokenization_kimi.py not found\")\n\n    tokenizer = load_tokenizer_for_model_id(model_id, model_path)\n    eos_token_ids = get_eos_token_ids_for_model(model_id)\n\n    # Test encode/decode cycle\n    test_text = \"Hello, world!\"\n    encoded = tokenizer.encode(test_text)\n    decoded = tokenizer.decode(encoded)\n\n    assert len(encoded) > 0, \"Kimi tokenizer should encode text\"\n    assert isinstance(decoded, str), \"Kimi tokenizer should decode to string\"\n\n    # Test that the patched encode works (returns list of ints)\n    assert all(isinstance(t, int) for t in encoded), \"Tokens should be integers\"\n\n    # Test encoding text with special tokens (like from chat templates)\n    # This is critical - the warmup inference uses prompts with special tokens\n    special_token_text = \"<|im_user|>user<|im_middle|>Hello<|im_end|><|im_assistant|>\"\n    special_encoded = tokenizer.encode(special_token_text)\n    assert len(special_encoded) > 0, \"Kimi tokenizer should handle special tokens\"\n    assert all(isinstance(t, int) for t in special_encoded), (\n        \"Special token encoding should return integers\"\n    )\n\n    # Verify EOS token is set\n    assert eos_token_ids == [163586], \"Kimi EOS token should be [163586]\"\n\n\n# Test GLM tokenizer since it also has special handling\n@pytest.mark.asyncio\nasync def test_glm_tokenizer_specifically():\n    \"\"\"Test GLM tokenizer with its specific EOS tokens.\"\"\"\n\n    def contains(card: ModelCard, x: str):\n        return x in card.model_id.lower()\n\n    glm_model_cards = [\n        card\n        for card in await get_model_cards()\n        if contains(card, \"glm\")\n        and not contains(card, \"-5\")\n        and not contains(card, \"4.7\")\n    ]\n\n    if not glm_model_cards:\n        pytest.skip(\"No GLM models found in MODEL_CARDS\")\n\n    model_card = glm_model_cards[0]\n    model_id = model_card.model_id\n\n    model_path = await download_tokenizer_files(model_id)\n\n    has_tokenizer = (model_path / \"tokenizer.json\").exists() or (\n        model_path / \"tokenizer_config.json\"\n    ).exists()\n    if not has_tokenizer:\n        pytest.skip(\"GLM tokenizer files not found\")\n\n    tokenizer = load_tokenizer_for_model_id(model_id, model_path)\n    eos_token_ids = get_eos_token_ids_for_model(model_id)\n\n    # Test encode/decode\n    test_text = \"Hello, world!\"\n    encoded = tokenizer.encode(test_text)\n    decoded = tokenizer.decode(encoded)\n\n    assert len(encoded) > 0, \"GLM tokenizer should encode text\"\n    assert isinstance(decoded, str), \"GLM tokenizer should decode to string\"\n\n    # Verify EOS tokens\n    assert eos_token_ids == [\n        151336,\n        151329,\n        151338,\n    ], \"GLM EOS tokens should be correct\"\n"
  },
  {
    "path": "src/exo/worker/tests/unittests/test_plan/__init__.py",
    "content": ""
  },
  {
    "path": "src/exo/worker/tests/unittests/test_plan/test_download_and_loading.py",
    "content": "import exo.worker.plan as plan_mod\nfrom exo.shared.types.common import NodeId\nfrom exo.shared.types.memory import Memory\nfrom exo.shared.types.tasks import LoadModel\nfrom exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress\nfrom exo.shared.types.worker.instances import BoundInstance\nfrom exo.shared.types.worker.runners import (\n    RunnerConnected,\n    RunnerIdle,\n)\nfrom exo.worker.tests.constants import (\n    INSTANCE_1_ID,\n    MODEL_A_ID,\n    NODE_A,\n    NODE_B,\n    RUNNER_1_ID,\n    RUNNER_2_ID,\n)\nfrom exo.worker.tests.unittests.conftest import (\n    FakeRunnerSupervisor,\n    get_mlx_ring_instance,\n    get_pipeline_shard_metadata,\n)\n\n\ndef test_plan_requests_download_when_waiting_and_shard_not_downloaded():\n    \"\"\"\n    When a runner is waiting for a model and its shard is not in the\n    local download_status map, plan() should emit DownloadModel.\n    \"\"\"\n\n    shard = get_pipeline_shard_metadata(model_id=MODEL_A_ID, device_rank=0)\n    instance = get_mlx_ring_instance(\n        instance_id=INSTANCE_1_ID,\n        model_id=MODEL_A_ID,\n        node_to_runner={NODE_A: RUNNER_1_ID},\n        runner_to_shard={RUNNER_1_ID: shard},\n    )\n    bound_instance = BoundInstance(\n        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A\n    )\n    runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerIdle())\n\n    runners = {RUNNER_1_ID: runner}\n    instances = {INSTANCE_1_ID: instance}\n    all_runners = {RUNNER_1_ID: RunnerIdle()}\n\n    result = plan_mod.plan(\n        node_id=NODE_A,\n        runners=runners,  # type: ignore\n        global_download_status={NODE_A: []},\n        instances=instances,\n        all_runners=all_runners,\n        tasks={},\n    )\n\n    assert isinstance(result, plan_mod.DownloadModel)\n    assert result.instance_id == INSTANCE_1_ID\n    assert result.shard_metadata == shard\n\n\ndef test_plan_loads_model_when_all_shards_downloaded_and_waiting():\n    \"\"\"\n    When all shards for an instance are DownloadCompleted (globally) and\n    all runners are in waiting/loading/loaded states, plan() should emit\n    LoadModel once.\n    \"\"\"\n    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)\n    shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)\n    instance = get_mlx_ring_instance(\n        instance_id=INSTANCE_1_ID,\n        model_id=MODEL_A_ID,\n        node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},\n        runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},\n    )\n    bound_instance = BoundInstance(\n        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A\n    )\n    local_runner = FakeRunnerSupervisor(\n        bound_instance=bound_instance, status=RunnerConnected()\n    )\n\n    runners = {RUNNER_1_ID: local_runner}\n    instances = {INSTANCE_1_ID: instance}\n\n    all_runners = {\n        RUNNER_1_ID: RunnerConnected(),\n        RUNNER_2_ID: RunnerConnected(),\n    }\n\n    global_download_status = {\n        NODE_A: [\n            DownloadCompleted(shard_metadata=shard1, node_id=NODE_A, total=Memory())\n        ],\n        NODE_B: [\n            DownloadCompleted(shard_metadata=shard2, node_id=NODE_B, total=Memory())\n        ],\n    }\n\n    result = plan_mod.plan(\n        node_id=NODE_A,\n        runners=runners,  # type: ignore\n        global_download_status=global_download_status,\n        instances=instances,\n        
all_runners=all_runners,\n        tasks={},\n    )\n\n    assert isinstance(result, LoadModel)\n    assert result.instance_id == INSTANCE_1_ID\n\n\ndef test_plan_does_not_request_download_when_shard_already_downloaded():\n    \"\"\"\n    If the local shard already has a DownloadCompleted entry, plan()\n    should not re-emit DownloadModel while global state is still catching up.\n    \"\"\"\n    shard = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0)\n    instance = get_mlx_ring_instance(\n        instance_id=INSTANCE_1_ID,\n        model_id=MODEL_A_ID,\n        node_to_runner={NODE_A: RUNNER_1_ID},\n        runner_to_shard={RUNNER_1_ID: shard},\n    )\n    bound_instance = BoundInstance(\n        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A\n    )\n    runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerIdle())\n\n    runners = {RUNNER_1_ID: runner}\n    instances = {INSTANCE_1_ID: instance}\n    all_runners = {RUNNER_1_ID: RunnerIdle()}\n\n    # Global state shows shard is downloaded for NODE_A\n    global_download_status: dict[NodeId, list[DownloadProgress]] = {\n        NODE_A: [\n            DownloadCompleted(shard_metadata=shard, node_id=NODE_A, total=Memory())\n        ],\n        NODE_B: [],\n    }\n\n    result = plan_mod.plan(\n        node_id=NODE_A,\n        runners=runners,  # type: ignore\n        global_download_status=global_download_status,\n        instances=instances,\n        all_runners=all_runners,\n        tasks={},\n    )\n\n    assert not isinstance(result, plan_mod.DownloadModel)\n\n\ndef test_plan_does_not_load_model_until_all_shards_downloaded_globally():\n    \"\"\"\n    LoadModel should not be emitted while some shards are still missing from\n    the global_download_status.\n    \"\"\"\n    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)\n    shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)\n    instance = get_mlx_ring_instance(\n        instance_id=INSTANCE_1_ID,\n        model_id=MODEL_A_ID,\n        node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},\n        runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},\n    )\n\n    bound_instance = BoundInstance(\n        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A\n    )\n    local_runner = FakeRunnerSupervisor(\n        bound_instance=bound_instance, status=RunnerConnected()\n    )\n\n    runners = {RUNNER_1_ID: local_runner}\n    instances = {INSTANCE_1_ID: instance}\n    all_runners = {\n        RUNNER_1_ID: RunnerConnected(),\n        RUNNER_2_ID: RunnerConnected(),\n    }\n\n    global_download_status = {\n        NODE_A: [\n            DownloadCompleted(shard_metadata=shard1, node_id=NODE_A, total=Memory())\n        ],\n        NODE_B: [],  # NODE_B has no downloads completed yet\n    }\n\n    result = plan_mod.plan(\n        node_id=NODE_A,\n        runners=runners,  # type: ignore\n        global_download_status=global_download_status,\n        instances=instances,\n        all_runners=all_runners,\n        tasks={},\n    )\n\n    assert result is None\n\n    global_download_status = {\n        NODE_A: [\n            DownloadCompleted(shard_metadata=shard1, node_id=NODE_A, total=Memory())\n        ],\n        NODE_B: [\n            DownloadCompleted(shard_metadata=shard2, node_id=NODE_B, total=Memory())\n        ],  # NODE_B download is now complete\n    }\n\n    result = plan_mod.plan(\n        node_id=NODE_A,\n        runners=runners,  # type: 
ignore\n        global_download_status=global_download_status,\n        instances=instances,\n        all_runners=all_runners,\n        tasks={},\n    )\n\n    assert result is not None\n"
  },
  {
    "path": "src/exo/worker/tests/unittests/test_plan/test_runner_lifecycle.py",
    "content": "from typing import Any\n\nimport exo.worker.plan as plan_mod\nfrom exo.shared.types.tasks import Shutdown\nfrom exo.shared.types.worker.instances import BoundInstance, Instance, InstanceId\nfrom exo.shared.types.worker.runners import (\n    RunnerFailed,\n    RunnerId,\n    RunnerReady,\n    RunnerStatus,\n)\nfrom exo.worker.tests.constants import (\n    INSTANCE_1_ID,\n    MODEL_A_ID,\n    NODE_A,\n    NODE_B,\n    RUNNER_1_ID,\n    RUNNER_2_ID,\n)\nfrom exo.worker.tests.unittests.conftest import (\n    FakeRunnerSupervisor,\n    get_mlx_ring_instance,\n    get_pipeline_shard_metadata,\n)\n\n\ndef test_plan_kills_runner_when_instance_missing():\n    \"\"\"\n    If a local runner's instance is no longer present in state,\n    plan() should return a Shutdown for that runner.\n    \"\"\"\n    shard = get_pipeline_shard_metadata(model_id=MODEL_A_ID, device_rank=0)\n    instance = get_mlx_ring_instance(\n        instance_id=INSTANCE_1_ID,\n        model_id=MODEL_A_ID,\n        node_to_runner={NODE_A: RUNNER_1_ID},\n        runner_to_shard={RUNNER_1_ID: shard},\n    )\n    bound_instance = BoundInstance(\n        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A\n    )\n    runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerReady())\n\n    runners = {RUNNER_1_ID: runner}\n    instances: dict[InstanceId, Instance] = {}\n    all_runners = {RUNNER_1_ID: RunnerReady()}\n\n    result = plan_mod.plan(\n        node_id=NODE_A,\n        runners=runners,  # type: ignore[arg-type]\n        global_download_status={NODE_A: []},\n        instances=instances,\n        all_runners=all_runners,\n        tasks={},\n    )\n\n    assert isinstance(result, Shutdown)\n    assert result.instance_id == INSTANCE_1_ID\n    assert result.runner_id == RUNNER_1_ID\n\n\ndef test_plan_kills_runner_when_sibling_failed():\n    \"\"\"\n    If a sibling runner in the same instance has failed, the local runner\n    should be shut down.\n    \"\"\"\n    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)\n    shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)\n    instance = get_mlx_ring_instance(\n        instance_id=INSTANCE_1_ID,\n        model_id=MODEL_A_ID,\n        node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},\n        runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},\n    )\n    bound_instance = BoundInstance(\n        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A\n    )\n    runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerReady())\n\n    runners = {RUNNER_1_ID: runner}\n    instances = {INSTANCE_1_ID: instance}\n    all_runners = {\n        RUNNER_1_ID: RunnerReady(),\n        RUNNER_2_ID: RunnerFailed(error_message=\"boom\"),\n    }\n\n    result = plan_mod.plan(\n        node_id=NODE_A,\n        runners=runners,  # type: ignore[arg-type]\n        global_download_status={NODE_A: []},\n        instances=instances,\n        all_runners=all_runners,\n        tasks={},\n    )\n\n    assert isinstance(result, Shutdown)\n    assert result.instance_id == INSTANCE_1_ID\n    assert result.runner_id == RUNNER_1_ID\n\n\ndef test_plan_creates_runner_when_missing_for_node():\n    \"\"\"\n    If shard_assignments specify a runner for this node but we don't have\n    a local supervisor yet, plan() should emit a CreateRunner.\n    \"\"\"\n    shard = get_pipeline_shard_metadata(model_id=MODEL_A_ID, device_rank=0)\n    instance = get_mlx_ring_instance(\n     
   instance_id=INSTANCE_1_ID,\n        model_id=MODEL_A_ID,\n        node_to_runner={NODE_A: RUNNER_1_ID},\n        runner_to_shard={RUNNER_1_ID: shard},\n    )\n\n    runners: dict[Any, Any] = {}  # nothing local yet\n    instances = {INSTANCE_1_ID: instance}\n    all_runners: dict[Any, Any] = {}\n\n    result = plan_mod.plan(\n        node_id=NODE_A,\n        runners=runners,\n        global_download_status={NODE_A: []},\n        instances=instances,\n        all_runners=all_runners,\n        tasks={},\n    )\n\n    assert isinstance(result, plan_mod.CreateRunner)\n    assert result.instance_id == INSTANCE_1_ID\n    assert isinstance(result.bound_instance, BoundInstance)\n    assert result.bound_instance.instance is instance\n    assert result.bound_instance.bound_runner_id == RUNNER_1_ID\n\n\ndef test_plan_does_not_create_runner_when_supervisor_already_present():\n    \"\"\"\n    If we already have a local supervisor for the runner assigned to this node,\n    plan() should not emit a CreateRunner again.\n    \"\"\"\n    shard = get_pipeline_shard_metadata(model_id=MODEL_A_ID, device_rank=0)\n    instance = get_mlx_ring_instance(\n        instance_id=INSTANCE_1_ID,\n        model_id=MODEL_A_ID,\n        node_to_runner={NODE_A: RUNNER_1_ID},\n        runner_to_shard={RUNNER_1_ID: shard},\n    )\n    bound_instance = BoundInstance(\n        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A\n    )\n    runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerReady())\n\n    runners = {RUNNER_1_ID: runner}\n    instances = {INSTANCE_1_ID: instance}\n    all_runners = {RUNNER_1_ID: RunnerReady()}\n\n    result = plan_mod.plan(\n        node_id=NODE_A,\n        runners=runners,  # type: ignore[arg-type]\n        global_download_status={NODE_A: []},\n        instances=instances,\n        all_runners=all_runners,\n        tasks={},\n    )\n\n    assert result is None\n\n\ndef test_plan_does_not_create_runner_for_unassigned_node():\n    \"\"\"\n    If this node does not appear in shard_assignments.node_to_runner,\n    plan() should not try to create a runner on this node.\n    \"\"\"\n    shard = get_pipeline_shard_metadata(model_id=MODEL_A_ID, device_rank=0)\n    instance = get_mlx_ring_instance(\n        instance_id=INSTANCE_1_ID,\n        model_id=MODEL_A_ID,\n        node_to_runner={NODE_B: RUNNER_2_ID},\n        runner_to_shard={RUNNER_2_ID: shard},\n    )\n\n    runners: dict[RunnerId, FakeRunnerSupervisor] = {}  # no local runners\n    instances = {INSTANCE_1_ID: instance}\n    all_runners: dict[RunnerId, RunnerStatus] = {}\n\n    result = plan_mod.plan(\n        node_id=NODE_A,\n        runners=runners,  # type: ignore\n        global_download_status={NODE_A: []},\n        instances=instances,\n        all_runners=all_runners,\n        tasks={},\n    )\n\n    assert result is None\n"
  },
  {
    "path": "src/exo/worker/tests/unittests/test_plan/test_task_forwarding.py",
    "content": "from typing import cast\n\nimport exo.worker.plan as plan_mod\nfrom exo.shared.types.tasks import Task, TaskId, TaskStatus, TextGeneration\nfrom exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams\nfrom exo.shared.types.worker.instances import BoundInstance, InstanceId\nfrom exo.shared.types.worker.runners import (\n    RunnerIdle,\n    RunnerReady,\n    RunnerRunning,\n)\nfrom exo.worker.tests.constants import (\n    COMMAND_1_ID,\n    INSTANCE_1_ID,\n    MODEL_A_ID,\n    NODE_A,\n    NODE_B,\n    RUNNER_1_ID,\n    RUNNER_2_ID,\n    TASK_1_ID,\n)\nfrom exo.worker.tests.unittests.conftest import (\n    FakeRunnerSupervisor,\n    OtherTask,\n    get_mlx_ring_instance,\n    get_pipeline_shard_metadata,\n)\n\n\ndef test_plan_forwards_pending_chat_completion_when_runner_ready():\n    \"\"\"\n    When there is a pending TextGeneration for the local instance and all\n    runners are Ready/Running, plan() should forward that task.\n    \"\"\"\n    shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)\n    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)\n    instance = get_mlx_ring_instance(\n        instance_id=INSTANCE_1_ID,\n        model_id=MODEL_A_ID,\n        node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},\n        runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},\n    )\n    bound_instance = BoundInstance(\n        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A\n    )\n    local_runner = FakeRunnerSupervisor(\n        bound_instance=bound_instance, status=RunnerReady()\n    )\n\n    runners = {RUNNER_1_ID: local_runner}\n    instances = {INSTANCE_1_ID: instance}\n    all_runners = {\n        RUNNER_1_ID: RunnerReady(),\n        RUNNER_2_ID: RunnerReady(),\n    }\n\n    task = TextGeneration(\n        task_id=TASK_1_ID,\n        instance_id=INSTANCE_1_ID,\n        task_status=TaskStatus.Pending,\n        command_id=COMMAND_1_ID,\n        task_params=TextGenerationTaskParams(\n            model=MODEL_A_ID, input=[InputMessage(role=\"user\", content=\"\")]\n        ),\n    )\n\n    result = plan_mod.plan(\n        node_id=NODE_A,\n        runners=runners,  # type: ignore\n        global_download_status={NODE_A: []},\n        instances=instances,\n        all_runners=all_runners,\n        tasks={TASK_1_ID: task},\n    )\n\n    assert result is task\n\n\ndef test_plan_does_not_forward_chat_completion_if_any_runner_not_ready():\n    \"\"\"\n    Even with a pending TextGeneration, plan() should not forward it unless\n    all runners for the instance are Ready/Running.\n    \"\"\"\n    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)\n    shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)\n    instance = get_mlx_ring_instance(\n        instance_id=INSTANCE_1_ID,\n        model_id=MODEL_A_ID,\n        node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},\n        runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},\n    )\n    bound_instance = BoundInstance(\n        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A\n    )\n    local_runner = FakeRunnerSupervisor(\n        bound_instance=bound_instance, status=RunnerReady()\n    )\n\n    runners = {RUNNER_1_ID: local_runner}\n    instances = {INSTANCE_1_ID: instance}\n    all_runners = {\n        RUNNER_1_ID: RunnerReady(),\n        RUNNER_2_ID: RunnerIdle(),\n    }\n\n    task = TextGeneration(\n        task_id=TASK_1_ID,\n   
     instance_id=INSTANCE_1_ID,\n        task_status=TaskStatus.Pending,\n        command_id=COMMAND_1_ID,\n        task_params=TextGenerationTaskParams(\n            model=MODEL_A_ID, input=[InputMessage(role=\"user\", content=\"\")]\n        ),\n    )\n\n    result = plan_mod.plan(\n        node_id=NODE_A,\n        runners=runners,  # type: ignore\n        global_download_status={NODE_A: [], NODE_B: []},\n        instances=instances,\n        all_runners=all_runners,\n        tasks={TASK_1_ID: task},\n    )\n\n    assert result is None\n\n\ndef test_plan_does_not_forward_tasks_for_other_instances():\n    \"\"\"\n    plan() should ignore pending TextGeneration tasks whose instance_id does\n    not match the local instance.\n    \"\"\"\n    shard = get_pipeline_shard_metadata(model_id=MODEL_A_ID, device_rank=0)\n    local_instance = get_mlx_ring_instance(\n        instance_id=INSTANCE_1_ID,\n        model_id=MODEL_A_ID,\n        node_to_runner={NODE_A: RUNNER_1_ID},\n        runner_to_shard={RUNNER_1_ID: shard},\n    )\n    bound_instance = BoundInstance(\n        instance=local_instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A\n    )\n    local_runner = FakeRunnerSupervisor(\n        bound_instance=bound_instance, status=RunnerReady()\n    )\n\n    runners = {RUNNER_1_ID: local_runner}\n    instances = {INSTANCE_1_ID: local_instance}\n    all_runners = {RUNNER_1_ID: RunnerReady()}\n\n    other_instance_id = InstanceId(\"instance-2\")\n    foreign_task = TextGeneration(\n        task_id=TaskId(\"other-task\"),\n        instance_id=other_instance_id,\n        task_status=TaskStatus.Pending,\n        command_id=COMMAND_1_ID,\n        task_params=TextGenerationTaskParams(\n            model=MODEL_A_ID, input=[InputMessage(role=\"user\", content=\"\")]\n        ),\n    )\n\n    result = plan_mod.plan(\n        node_id=NODE_A,\n        runners=runners,  # type: ignore\n        global_download_status={NODE_A: []},\n        instances=instances,\n        all_runners=all_runners,\n        tasks={foreign_task.task_id: foreign_task},\n    )\n\n    assert result is None\n\n\ndef test_plan_ignores_non_pending_or_non_chat_tasks():\n    \"\"\"\n    _pending_tasks should not forward tasks that are either not TextGeneration\n    or not in Pending/Running states.\n    \"\"\"\n    shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)\n    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)\n    instance = get_mlx_ring_instance(\n        instance_id=INSTANCE_1_ID,\n        model_id=MODEL_A_ID,\n        node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},\n        runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},\n    )\n    bound_instance = BoundInstance(\n        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A\n    )\n\n    local_runner = FakeRunnerSupervisor(\n        bound_instance=bound_instance, status=RunnerReady()\n    )\n\n    runners = {RUNNER_1_ID: local_runner}\n    instances = {INSTANCE_1_ID: instance}\n    all_runners = {\n        RUNNER_1_ID: RunnerReady(),\n        RUNNER_2_ID: RunnerReady(),\n    }\n\n    completed_task = TextGeneration(\n        task_id=TASK_1_ID,\n        instance_id=INSTANCE_1_ID,\n        task_status=TaskStatus.Complete,\n        command_id=COMMAND_1_ID,\n        task_params=TextGenerationTaskParams(\n            model=MODEL_A_ID, input=[InputMessage(role=\"user\", content=\"\")]\n        ),\n    )\n\n    other_task_id = TaskId(\"other-task\")\n\n    other_task = cast(\n        
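# OtherTask (from conftest) is deliberately not a TextGeneration; the\n        # double cast through object lets it stand in as a Task here.\n        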
Task,\n        cast(\n            object,\n            OtherTask(\n                task_id=other_task_id,\n                instance_id=INSTANCE_1_ID,\n                task_status=TaskStatus.Pending,\n            ),\n        ),\n    )\n\n    result = plan_mod.plan(\n        node_id=NODE_A,\n        runners=runners,  # type: ignore\n        global_download_status={NODE_A: [], NODE_B: []},\n        instances=instances,\n        all_runners=all_runners,\n        tasks={TASK_1_ID: completed_task, other_task_id: other_task},\n    )\n\n    assert result is None\n\n\ndef test_plan_returns_none_when_nothing_to_do():\n    \"\"\"\n    If there are healthy runners, no downloads needed, and no pending tasks,\n    plan() should return None (steady state).\n    \"\"\"\n    shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)\n    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)\n    instance = get_mlx_ring_instance(\n        instance_id=INSTANCE_1_ID,\n        model_id=MODEL_A_ID,\n        node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},\n        runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},\n    )\n    bound_instance = BoundInstance(\n        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A\n    )\n    local_runner = FakeRunnerSupervisor(\n        bound_instance=bound_instance, status=RunnerRunning()\n    )\n\n    runners = {RUNNER_1_ID: local_runner}\n    instances = {INSTANCE_1_ID: instance}\n    all_runners = {\n        RUNNER_1_ID: RunnerRunning(),\n        RUNNER_2_ID: RunnerRunning(),\n    }\n\n    result = plan_mod.plan(\n        node_id=NODE_A,\n        runners=runners,  # type: ignore\n        global_download_status={NODE_A: [], NODE_B: []},\n        instances=instances,\n        all_runners=all_runners,\n        tasks={},\n    )\n\n    assert result is None\n"
  },
  {
    "path": "src/exo/worker/tests/unittests/test_plan/test_warmup.py",
    "content": "import exo.worker.plan as plan_mod\nfrom exo.shared.types.tasks import StartWarmup\nfrom exo.shared.types.worker.instances import BoundInstance\nfrom exo.shared.types.worker.runners import (\n    RunnerIdle,\n    RunnerLoaded,\n    RunnerLoading,\n    RunnerWarmingUp,\n)\nfrom exo.worker.tests.constants import (\n    INSTANCE_1_ID,\n    MODEL_A_ID,\n    NODE_A,\n    NODE_B,\n    NODE_C,\n    RUNNER_1_ID,\n    RUNNER_2_ID,\n    RUNNER_3_ID,\n)\nfrom exo.worker.tests.unittests.conftest import (\n    FakeRunnerSupervisor,\n    get_mlx_ring_instance,\n    get_pipeline_shard_metadata,\n)\n\n\ndef test_plan_starts_warmup_for_accepting_rank_when_all_loaded_or_warming():\n    \"\"\"\n    For non-zero device_rank shards, StartWarmup should be emitted when all\n    shards in the instance are Loaded/WarmingUp.\n    \"\"\"\n    shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=3)\n    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=3)\n    shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=2, world_size=3)\n    instance = get_mlx_ring_instance(\n        instance_id=INSTANCE_1_ID,\n        model_id=MODEL_A_ID,\n        node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID, NODE_C: RUNNER_3_ID},\n        runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1, RUNNER_3_ID: shard2},\n    )\n\n    bound_instance = BoundInstance(\n        instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B\n    )\n    local_runner = FakeRunnerSupervisor(\n        bound_instance=bound_instance, status=RunnerLoaded()\n    )\n\n    runners = {RUNNER_2_ID: local_runner}\n    instances = {INSTANCE_1_ID: instance}\n    all_runners = {\n        RUNNER_1_ID: RunnerLoaded(),\n        RUNNER_2_ID: RunnerLoaded(),\n        RUNNER_3_ID: RunnerWarmingUp(),\n    }\n\n    result = plan_mod.plan(\n        node_id=NODE_B,\n        runners=runners,  # type: ignore\n        global_download_status={NODE_A: []},\n        instances=instances,\n        all_runners=all_runners,\n        tasks={},\n    )\n\n    assert isinstance(result, StartWarmup)\n    assert result.instance_id == INSTANCE_1_ID\n\n\ndef test_plan_starts_warmup_for_rank_zero_after_others_warming():\n    \"\"\"\n    For device_rank == 0, StartWarmup should only be emitted once all the\n    other runners in the instance are already warming up.\n    \"\"\"\n    shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)\n    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)\n    instance = get_mlx_ring_instance(\n        instance_id=INSTANCE_1_ID,\n        model_id=MODEL_A_ID,\n        node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},\n        runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},\n    )\n\n    bound_instance = BoundInstance(\n        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A\n    )\n    local_runner = FakeRunnerSupervisor(\n        bound_instance=bound_instance, status=RunnerLoaded()\n    )\n\n    runners = {RUNNER_1_ID: local_runner}\n    instances = {INSTANCE_1_ID: instance}\n    all_runners = {\n        RUNNER_1_ID: RunnerLoaded(),\n        RUNNER_2_ID: RunnerWarmingUp(),\n    }\n\n    result = plan_mod.plan(\n        node_id=NODE_A,\n        runners=runners,  # type: ignore\n        global_download_status={NODE_A: []},\n        instances=instances,\n        all_runners=all_runners,\n        tasks={},\n    )\n\n    assert isinstance(result, StartWarmup)\n    assert 
result.instance_id == INSTANCE_1_ID\n\n\ndef test_plan_does_not_start_warmup_for_non_zero_rank_until_all_loaded_or_warming():\n    \"\"\"\n    Non-zero rank should not start warmup while any shard is not Loaded/WarmingUp.\n    \"\"\"\n    shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)\n    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)\n    instance = get_mlx_ring_instance(\n        instance_id=INSTANCE_1_ID,\n        model_id=MODEL_A_ID,\n        node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},\n        runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},\n    )\n\n    bound_instance = BoundInstance(\n        instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B\n    )\n    local_runner = FakeRunnerSupervisor(\n        bound_instance=bound_instance, status=RunnerLoaded()\n    )\n\n    runners = {RUNNER_2_ID: local_runner}\n    instances = {INSTANCE_1_ID: instance}\n    all_runners = {\n        RUNNER_1_ID: RunnerIdle(),\n        RUNNER_2_ID: RunnerLoaded(),\n    }\n\n    result = plan_mod.plan(\n        node_id=NODE_B,\n        runners=runners,  # type: ignore\n        global_download_status={NODE_A: [], NODE_B: []},\n        instances=instances,\n        all_runners=all_runners,\n        tasks={},\n    )\n\n    assert result is None\n\n\ndef test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():\n    \"\"\"\n    Rank-zero shard should not start warmup until all non-zero ranks are\n    already WarmingUp; once they are, it should emit StartWarmup.\n    \"\"\"\n    shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)\n    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)\n    instance = get_mlx_ring_instance(\n        instance_id=INSTANCE_1_ID,\n        model_id=MODEL_A_ID,\n        node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},\n        runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},\n    )\n\n    # The local runner is rank 0\n    bound_instance = BoundInstance(\n        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A\n    )\n    local_runner = FakeRunnerSupervisor(\n        bound_instance=bound_instance, status=RunnerLoaded()\n    )\n\n    runners = {RUNNER_1_ID: local_runner}\n    instances = {INSTANCE_1_ID: instance}\n    all_runners = {\n        RUNNER_1_ID: RunnerLoaded(),\n        RUNNER_2_ID: RunnerLoaded(),\n    }\n\n    result = plan_mod.plan(\n        node_id=NODE_A,\n        runners=runners,  # type: ignore\n        global_download_status={NODE_A: []},\n        instances=instances,\n        all_runners=all_runners,\n        tasks={},\n    )\n\n    assert result is None\n\n    all_runners = {\n        RUNNER_1_ID: RunnerLoaded(),\n        RUNNER_2_ID: RunnerWarmingUp(),\n    }\n\n    result = plan_mod.plan(\n        node_id=NODE_A,\n        runners=runners,  # type: ignore\n        global_download_status={NODE_A: []},\n        instances=instances,\n        all_runners=all_runners,\n        tasks={},\n    )\n\n    assert isinstance(result, StartWarmup)\n    assert result.instance_id == INSTANCE_1_ID\n\n\ndef test_plan_starts_warmup_for_connecting_rank_after_others_warming():\n    \"\"\"\n    For connecting rank (device_rank == world_size - 1), StartWarmup should\n    only be emitted once all the other runners are already warming 
up.\n    In a 2-node setup, rank 1 is the connecting rank.\n    \"\"\"\n    shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)\n    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)\n    instance = get_mlx_ring_instance(\n        instance_id=INSTANCE_1_ID,\n        model_id=MODEL_A_ID,\n        node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},\n        runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},\n    )\n\n    # Rank 1 is the connecting rank\n    bound_instance = BoundInstance(\n        instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B\n    )\n    local_runner = FakeRunnerSupervisor(\n        bound_instance=bound_instance, status=RunnerLoaded()\n    )\n\n    runners = {RUNNER_2_ID: local_runner}\n    instances = {INSTANCE_1_ID: instance}\n    all_runners = {\n        RUNNER_1_ID: RunnerWarmingUp(),\n        RUNNER_2_ID: RunnerLoaded(),\n    }\n\n    result = plan_mod.plan(\n        node_id=NODE_B,\n        runners=runners,  # type: ignore\n        global_download_status={NODE_B: []},\n        instances=instances,\n        all_runners=all_runners,\n        tasks={},\n    )\n\n    assert isinstance(result, StartWarmup)\n    assert result.instance_id == INSTANCE_1_ID\n\n\ndef test_plan_does_not_start_warmup_for_accepting_rank_until_all_loaded_or_warming():\n    \"\"\"\n    Accepting rank should not start warmup while any shard is not Loaded/WarmingUp.\n    In a 2-node setup, rank 0 is the accepting rank.\n    \"\"\"\n    shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)\n    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)\n    instance = get_mlx_ring_instance(\n        instance_id=INSTANCE_1_ID,\n        model_id=MODEL_A_ID,\n        node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},\n        runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},\n    )\n\n    # Rank 0 is the accepting rank\n    bound_instance = BoundInstance(\n        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A\n    )\n    local_runner = FakeRunnerSupervisor(\n        bound_instance=bound_instance, status=RunnerLoaded()\n    )\n\n    runners = {RUNNER_1_ID: local_runner}\n    instances = {INSTANCE_1_ID: instance}\n    all_runners = {\n        RUNNER_1_ID: RunnerLoaded(),\n        RUNNER_2_ID: RunnerLoading(),\n    }\n\n    result = plan_mod.plan(\n        node_id=NODE_A,\n        runners=runners,  # type: ignore\n        global_download_status={NODE_A: [], NODE_B: []},\n        instances=instances,\n        all_runners=all_runners,\n        tasks={},\n    )\n\n    assert result is None\n\n\ndef test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():\n    \"\"\"\n    Rank 0 should not start warmup\n    until all other ranks are already WarmingUp.\n    \"\"\"\n    shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)\n    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)\n    instance = get_mlx_ring_instance(\n        instance_id=INSTANCE_1_ID,\n        model_id=MODEL_A_ID,\n        node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},\n        runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},\n    )\n\n    # The local runner is rank 0\n    bound_instance = BoundInstance(\n        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A\n    )\n    local_runner = FakeRunnerSupervisor(\n        
bound_instance=bound_instance, status=RunnerLoaded()\n    )\n\n    runners = {RUNNER_1_ID: local_runner}\n    instances = {INSTANCE_1_ID: instance}\n    all_runners = {\n        RUNNER_1_ID: RunnerLoaded(),\n        RUNNER_2_ID: RunnerLoaded(),\n    }\n\n    result = plan_mod.plan(\n        node_id=NODE_A,\n        runners=runners,  # type: ignore\n        global_download_status={NODE_A: [], NODE_B: []},\n        instances=instances,\n        all_runners=all_runners,\n        tasks={},\n    )\n\n    assert result is None\n"
  },
  {
    "path": "src/exo/worker/tests/unittests/test_runner/__init__.py",
    "content": ""
  },
  {
    "path": "src/exo/worker/tests/unittests/test_runner/test_dsml_e2e.py",
    "content": "import json\nfrom collections.abc import Generator\nfrom typing import Any\n\nfrom exo.shared.types.worker.runner_response import (\n    GenerationResponse,\n    ToolCallResponse,\n)\nfrom exo.worker.engines.mlx.dsml_encoding import (\n    ASSISTANT_TOKEN,\n    BOS_TOKEN,\n    DSML_TOKEN,\n    EOS_TOKEN,\n    THINKING_END,\n    THINKING_START,\n    TOOL_CALLS_END,\n    TOOL_CALLS_START,\n    USER_TOKEN,\n    encode_messages,\n    parse_dsml_output,\n)\nfrom exo.worker.runner.llm_inference.model_output_parsers import parse_deepseek_v32\n\n# ── Shared fixtures ──────────────────────────────────────────────\n\n_WEATHER_TOOLS: list[dict[str, Any]] = [\n    {\n        \"type\": \"function\",\n        \"function\": {\n            \"name\": \"get_weather\",\n            \"description\": \"Get the current weather in a given city\",\n            \"parameters\": {\n                \"type\": \"object\",\n                \"properties\": {\n                    \"city\": {\"type\": \"string\", \"description\": \"The city name\"},\n                    \"units\": {\n                        \"type\": \"string\",\n                        \"enum\": [\"celsius\", \"fahrenheit\"],\n                        \"description\": \"Temperature units\",\n                    },\n                },\n                \"required\": [\"city\"],\n            },\n        },\n    },\n    {\n        \"type\": \"function\",\n        \"function\": {\n            \"name\": \"get_time\",\n            \"description\": \"Get the current time in a timezone\",\n            \"parameters\": {\n                \"type\": \"object\",\n                \"properties\": {\n                    \"timezone\": {\"type\": \"string\"},\n                },\n                \"required\": [\"timezone\"],\n            },\n        },\n    },\n]\n\n\ndef _simulate_tokens(\n    texts: list[str],\n    finish_on_last: bool = True,\n) -> Generator[GenerationResponse]:\n    \"\"\"Simulate a model producing tokens from a list of text strings.\"\"\"\n    for i, text in enumerate(texts):\n        is_last = i == len(texts) - 1\n        yield GenerationResponse(\n            text=text,\n            token=i,\n            finish_reason=\"stop\" if (is_last and finish_on_last) else None,\n            usage=None,\n        )\n\n\n# ── Test: Standard text response (no tool calls) ────────────────\n\n\nclass TestE2EStandardResponse:\n    \"\"\"Model generates a plain text response — no tool calling involved.\"\"\"\n\n    def test_plain_text_passthrough(self):\n        \"\"\"Simulate model producing: 'The weather in NYC is 72°F and sunny.'\"\"\"\n        # Step 1: Encode the prompt (with tools available)\n        messages: list[dict[str, Any]] = [\n            {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n            {\"role\": \"user\", \"content\": \"What's the weather in NYC?\"},\n        ]\n        prompt = encode_messages(messages, thinking_mode=\"chat\", tools=_WEATHER_TOOLS)\n\n        # Verify prompt structure\n        assert BOS_TOKEN in prompt\n        assert \"## Tools\" in prompt\n        assert \"get_weather\" in prompt\n        assert f\"{USER_TOKEN}What's the weather in NYC?{ASSISTANT_TOKEN}\" in prompt\n\n        # Step 2: Simulate model response — plain text tokens (no DSML)\n        model_tokens = [\n            \"The weather\",\n            \" in NYC\",\n            \" is 72\",\n            \"°F\",\n            \" and sunny\",\n            \".\",\n        ]\n        results = 
list(parse_deepseek_v32(_simulate_tokens(model_tokens)))\n\n        # Step 3: Verify all tokens pass through as GenerationResponse\n        gen_results = [r for r in results if isinstance(r, GenerationResponse)]\n        tool_results = [r for r in results if isinstance(r, ToolCallResponse)]\n\n        assert len(tool_results) == 0\n        assert len(gen_results) == 6\n        full_text = \"\".join(r.text for r in gen_results)\n        assert full_text == \"The weather in NYC is 72°F and sunny.\"\n        assert gen_results[-1].finish_reason == \"stop\"\n\n\n# ── Test: Tool call response ─────────────────────────────────────\n\n\nclass TestE2EToolCallResponse:\n    \"\"\"Model generates a DSML tool call — realistic token boundaries.\"\"\"\n\n    def test_realistic_tool_call_tokens(self):\n        \"\"\"Simulate model generating a get_weather tool call with realistic token splits.\n\n        Real models split DSML markers across tokens unpredictably.\n        This simulates how DeepSeek V3.2 actually tokenizes DSML output.\n        \"\"\"\n        # Step 1: Encode prompt\n        messages: list[dict[str, Any]] = [\n            {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n            {\"role\": \"user\", \"content\": \"What's the weather in San Francisco?\"},\n        ]\n        prompt = encode_messages(messages, thinking_mode=\"chat\", tools=_WEATHER_TOOLS)\n        assert \"get_weather\" in prompt\n\n        # Step 2: Simulate realistic token-by-token model output\n        # The model first produces some text, then a DSML tool call block\n        model_tokens = [\n            \"I'll check the weather for you.\",\n            \"\\n\\n\",\n            f\"<{DSML_TOKEN}\",  # marker split across tokens\n            \"function_calls>\\n\",\n            f'<{DSML_TOKEN}invoke name=\"get_weather\">\\n',\n            f'<{DSML_TOKEN}parameter name=\"city\" string=\"true\">',\n            \"San Francisco\",\n            f\"</{DSML_TOKEN}parameter>\\n\",\n            f'<{DSML_TOKEN}parameter name=\"units\" string=\"false\">',\n            '\"celsius\"',\n            f\"</{DSML_TOKEN}parameter>\\n\",\n            f\"</{DSML_TOKEN}invoke>\\n\",\n            f\"</{DSML_TOKEN}function_calls>\",\n        ]\n\n        results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))\n\n        # Step 3: Verify\n        gen_results = [r for r in results if isinstance(r, GenerationResponse)]\n        tool_results = [r for r in results if isinstance(r, ToolCallResponse)]\n\n        # Should have text tokens before tool call + one ToolCallResponse\n        assert len(tool_results) == 1\n        assert len(tool_results[0].tool_calls) == 1\n\n        tc = tool_results[0].tool_calls[0]\n        assert tc.name == \"get_weather\"\n        args = json.loads(tc.arguments)  # pyright: ignore[reportAny]\n        assert args[\"city\"] == \"San Francisco\"\n        assert args[\"units\"] == \"celsius\"\n\n        # The text before the tool call should still be yielded\n        text_before = \"\".join(r.text for r in gen_results if not r.is_thinking)\n        assert \"check the weather\" in text_before\n\n    def test_multiple_tool_calls_in_one_block(self):\n        \"\"\"Model generates two tool calls in a single function_calls block.\"\"\"\n        messages: list[dict[str, Any]] = [\n            {\"role\": \"system\", \"content\": \"You are helpful.\"},\n            {\"role\": \"user\", \"content\": \"Weather in NYC and time in EST?\"},\n        ]\n        prompt = encode_messages(messages, 
thinking_mode=\"chat\", tools=_WEATHER_TOOLS)\n        assert \"get_weather\" in prompt\n        assert \"get_time\" in prompt\n\n        # Simulate model output with two invocations\n        model_tokens = [\n            \"Let me check both.\\n\\n\",\n            TOOL_CALLS_START,\n            \"\\n\",\n            f'<{DSML_TOKEN}invoke name=\"get_weather\">\\n',\n            f'<{DSML_TOKEN}parameter name=\"city\" string=\"true\">NYC</{DSML_TOKEN}parameter>\\n',\n            f\"</{DSML_TOKEN}invoke>\\n\",\n            f'<{DSML_TOKEN}invoke name=\"get_time\">\\n',\n            f'<{DSML_TOKEN}parameter name=\"timezone\" string=\"true\">EST</{DSML_TOKEN}parameter>\\n',\n            f\"</{DSML_TOKEN}invoke>\\n\",\n            TOOL_CALLS_END,\n        ]\n\n        results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))\n        tool_results = [r for r in results if isinstance(r, ToolCallResponse)]\n\n        assert len(tool_results) == 1\n        assert len(tool_results[0].tool_calls) == 2\n        assert tool_results[0].tool_calls[0].name == \"get_weather\"\n        assert tool_results[0].tool_calls[1].name == \"get_time\"\n\n        args0 = json.loads(tool_results[0].tool_calls[0].arguments)  # pyright: ignore[reportAny]\n        args1 = json.loads(tool_results[0].tool_calls[1].arguments)  # pyright: ignore[reportAny]\n        assert args0 == {\"city\": \"NYC\"}\n        assert args1 == {\"timezone\": \"EST\"}\n\n\n# ── Test: Multi-turn tool use flow ───────────────────────────────\n\n\nclass TestE2EMultiTurnToolUse:\n    \"\"\"Full multi-turn: user asks → model calls tool → tool result → model answers.\"\"\"\n\n    def test_encode_multi_turn_with_tool_results(self):\n        \"\"\"Verify the prompt for turn 2 (after tool results) is correctly encoded.\"\"\"\n        # Turn 1: user asks, model calls tool\n        # Turn 2: tool result provided, model answers\n        messages: list[dict[str, Any]] = [\n            {\"role\": \"system\", \"content\": \"You are a weather assistant.\"},\n            {\"role\": \"user\", \"content\": \"What's the weather in NYC?\"},\n            {\n                \"role\": \"assistant\",\n                \"content\": \"\",\n                \"tool_calls\": [\n                    {\n                        \"type\": \"function\",\n                        \"function\": {\n                            \"name\": \"get_weather\",\n                            \"arguments\": '{\"city\": \"NYC\"}',\n                        },\n                    }\n                ],\n            },\n            {\"role\": \"tool\", \"content\": '{\"temperature\": 72, \"condition\": \"sunny\"}'},\n        ]\n\n        prompt = encode_messages(messages, thinking_mode=\"chat\", tools=_WEATHER_TOOLS)\n\n        # Verify multi-turn structure\n        assert BOS_TOKEN in prompt\n        assert \"You are a weather assistant.\" in prompt\n        assert \"## Tools\" in prompt\n\n        # The assistant's tool call should be encoded as DSML\n        assert TOOL_CALLS_START in prompt\n        assert f'<{DSML_TOKEN}invoke name=\"get_weather\">' in prompt\n        assert EOS_TOKEN in prompt\n\n        # The tool result should be wrapped in function_results\n        assert \"<function_results>\" in prompt\n        assert \"<result>\" in prompt\n        assert \"72\" in prompt\n        assert \"</function_results>\" in prompt\n\n        # Now simulate model answering after seeing the tool result\n        model_tokens = [\n            \"The current\",\n            \" weather in NYC\",\n    
        \" is 72°F\",\n            \" and sunny.\",\n        ]\n        results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))\n\n        gen_results = [r for r in results if isinstance(r, GenerationResponse)]\n        tool_results = [r for r in results if isinstance(r, ToolCallResponse)]\n\n        assert len(tool_results) == 0\n        full_text = \"\".join(r.text for r in gen_results)\n        assert full_text == \"The current weather in NYC is 72°F and sunny.\"\n\n    def test_multi_tool_results_encoding(self):\n        \"\"\"Verify encoding when model called two tools and both return results.\"\"\"\n        messages: list[dict[str, Any]] = [\n            {\"role\": \"user\", \"content\": \"Weather and time?\"},\n            {\n                \"role\": \"assistant\",\n                \"content\": \"\",\n                \"tool_calls\": [\n                    {\n                        \"type\": \"function\",\n                        \"function\": {\n                            \"name\": \"get_weather\",\n                            \"arguments\": '{\"city\": \"LA\"}',\n                        },\n                    },\n                    {\n                        \"type\": \"function\",\n                        \"function\": {\n                            \"name\": \"get_time\",\n                            \"arguments\": '{\"timezone\": \"PST\"}',\n                        },\n                    },\n                ],\n            },\n            {\"role\": \"tool\", \"content\": \"85F, clear skies\"},\n            {\"role\": \"tool\", \"content\": \"3:42 PM PST\"},\n        ]\n\n        prompt = encode_messages(messages, thinking_mode=\"chat\", tools=_WEATHER_TOOLS)\n\n        # Should have one function_results block with two results\n        assert prompt.count(\"<function_results>\") == 1\n        assert prompt.count(\"</function_results>\") == 1\n        assert \"<result>85F, clear skies</result>\" in prompt\n        assert \"<result>3:42 PM PST</result>\" in prompt\n\n\n# ── Test: Thinking + tool call ───────────────────────────────────\n\n\nclass TestE2EThinkingAndToolCall:\n    \"\"\"Model uses thinking mode, reasons, then makes a tool call.\"\"\"\n\n    def test_thinking_then_tool_call(self):\n        \"\"\"Model thinks first, then produces a DSML tool call block.\"\"\"\n        messages: list[dict[str, Any]] = [\n            {\"role\": \"user\", \"content\": \"What's the weather?\"},\n        ]\n        prompt = encode_messages(\n            messages, tools=_WEATHER_TOOLS, thinking_mode=\"thinking\"\n        )\n        # Thinking mode: prompt should end with <think>\n        assert prompt.endswith(THINKING_START)\n\n        # Simulate: model outputs <think>, thinks, closes thinking, then tool call.\n        # In the full pipeline, parse_thinking_models handles the case where\n        # <think> is in the prompt. Here we test parse_deepseek_v32 directly,\n        # which detects <think>/<think> markers in the stream.\n        model_tokens = [\n            THINKING_START,\n            \"The user wants weather\",\n            \" information. 
I should use\",\n            \" the get_weather tool.\",\n            THINKING_END,\n            \"\\n\\n\",\n            TOOL_CALLS_START,\n            \"\\n\",\n            f'<{DSML_TOKEN}invoke name=\"get_weather\">\\n',\n            f'<{DSML_TOKEN}parameter name=\"city\" string=\"true\">',\n            \"San Francisco\",\n            f\"</{DSML_TOKEN}parameter>\\n\",\n            f\"</{DSML_TOKEN}invoke>\\n\",\n            TOOL_CALLS_END,\n        ]\n\n        results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))\n\n        gen_results = [r for r in results if isinstance(r, GenerationResponse)]\n        tool_results = [r for r in results if isinstance(r, ToolCallResponse)]\n\n        # Should have thinking tokens + tool call\n        thinking_results = [r for r in gen_results if r.is_thinking]\n\n        assert len(thinking_results) >= 1\n        thinking_text = \"\".join(r.text for r in thinking_results)\n        assert \"get_weather tool\" in thinking_text\n\n        assert len(tool_results) == 1\n        assert tool_results[0].tool_calls[0].name == \"get_weather\"\n        args = json.loads(tool_results[0].tool_calls[0].arguments)  # pyright: ignore[reportAny]\n        assert args[\"city\"] == \"San Francisco\"\n\n    def test_thinking_prompt_encoding(self):\n        \"\"\"Verify thinking mode affects prompt encoding correctly.\"\"\"\n        messages: list[dict[str, Any]] = [\n            {\"role\": \"system\", \"content\": \"Be thorough.\"},\n            {\"role\": \"user\", \"content\": \"What's the weather?\"},\n        ]\n\n        # With thinking enabled\n        prompt_think = encode_messages(\n            messages, tools=_WEATHER_TOOLS, thinking_mode=\"thinking\"\n        )\n        assert prompt_think.endswith(THINKING_START)\n\n        # With thinking disabled\n        prompt_no_think = encode_messages(\n            messages, tools=_WEATHER_TOOLS, thinking_mode=\"chat\"\n        )\n        assert prompt_no_think.endswith(THINKING_END)\n\n        # Both should have the same tool definitions\n        assert \"get_weather\" in prompt_think\n        assert \"get_weather\" in prompt_no_think\n\n\n# ── Test: Round-trip encode → parse ──────────────────────────────\n\n\nclass TestE2ERoundTrip:\n    \"\"\"Verify that DSML we encode can be parsed back correctly.\"\"\"\n\n    def test_encoded_tool_call_is_parseable(self):\n        \"\"\"Encode an assistant tool call message, then parse the DSML output.\"\"\"\n        messages: list[dict[str, Any]] = [\n            {\"role\": \"user\", \"content\": \"Weather?\"},\n            {\n                \"role\": \"assistant\",\n                \"content\": \"\",\n                \"tool_calls\": [\n                    {\n                        \"type\": \"function\",\n                        \"function\": {\n                            \"name\": \"get_weather\",\n                            \"arguments\": '{\"city\": \"Tokyo\", \"units\": \"celsius\"}',\n                        },\n                    }\n                ],\n            },\n        ]\n\n        prompt = encode_messages(messages, thinking_mode=\"chat\", tools=_WEATHER_TOOLS)\n\n        # Extract the DSML function_calls block from the prompt\n        start = prompt.index(TOOL_CALLS_START)\n        end = prompt.index(TOOL_CALLS_END) + len(TOOL_CALLS_END)\n        dsml_block = prompt[start:end]\n\n        # Parse it back\n        parsed = parse_dsml_output(dsml_block)\n        assert parsed is not None\n        assert len(parsed) == 1\n        assert parsed[0].name 
== \"get_weather\"\n        args = json.loads(parsed[0].arguments)  # pyright: ignore[reportAny]\n        assert args[\"city\"] == \"Tokyo\"\n        assert args[\"units\"] == \"celsius\"\n\n    def test_encoded_multi_tool_call_round_trips(self):\n        \"\"\"Encode multiple tool calls, verify they parse back correctly.\"\"\"\n        messages: list[dict[str, Any]] = [\n            {\"role\": \"user\", \"content\": \"Both please\"},\n            {\n                \"role\": \"assistant\",\n                \"content\": \"\",\n                \"tool_calls\": [\n                    {\n                        \"type\": \"function\",\n                        \"function\": {\n                            \"name\": \"get_weather\",\n                            \"arguments\": '{\"city\": \"Paris\"}',\n                        },\n                    },\n                    {\n                        \"type\": \"function\",\n                        \"function\": {\n                            \"name\": \"get_time\",\n                            \"arguments\": '{\"timezone\": \"CET\"}',\n                        },\n                    },\n                ],\n            },\n        ]\n\n        prompt = encode_messages(messages, thinking_mode=\"chat\", tools=_WEATHER_TOOLS)\n\n        start = prompt.index(TOOL_CALLS_START)\n        end = prompt.index(TOOL_CALLS_END) + len(TOOL_CALLS_END)\n        dsml_block = prompt[start:end]\n\n        parsed = parse_dsml_output(dsml_block)\n        assert parsed is not None\n        assert len(parsed) == 2\n        assert parsed[0].name == \"get_weather\"\n        assert parsed[1].name == \"get_time\"\n        assert json.loads(parsed[0].arguments) == {\"city\": \"Paris\"}\n        assert json.loads(parsed[1].arguments) == {\"timezone\": \"CET\"}\n\n\n# ── Test: Edge cases with realistic token boundaries ─────────────\n\n\nclass TestE2EEdgeCases:\n    \"\"\"Edge cases that occur in real model inference.\"\"\"\n\n    def test_dsml_marker_split_at_fullwidth_pipe(self):\n        \"\"\"The fullwidth pipe character ｜ might be its own token.\"\"\"\n        # This is a realistic tokenization: the DSML marker is split at the ｜ chars\n        model_tokens = [\n            \"Let me help.\\n\\n\",\n            \"<\\uff5c\",  # start of ｜DSML｜\n            \"DSML\\uff5c\",  # rest of DSML token\n            \"function_calls>\\n\",\n            f'<{DSML_TOKEN}invoke name=\"get_weather\">\\n',\n            f'<{DSML_TOKEN}parameter name=\"city\" string=\"true\">NYC</{DSML_TOKEN}parameter>\\n',\n            f\"</{DSML_TOKEN}invoke>\\n\",\n            TOOL_CALLS_END,\n        ]\n\n        results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))\n        tool_results = [r for r in results if isinstance(r, ToolCallResponse)]\n\n        assert len(tool_results) == 1\n        assert tool_results[0].tool_calls[0].name == \"get_weather\"\n\n    def test_tool_call_with_nested_json_object(self):\n        \"\"\"Model passes a complex JSON object as a non-string parameter.\"\"\"\n        dsml_block = (\n            f\"{TOOL_CALLS_START}\\n\"\n            f'<{DSML_TOKEN}invoke name=\"create_event\">\\n'\n            f'<{DSML_TOKEN}parameter name=\"title\" string=\"true\">Team Standup</{DSML_TOKEN}parameter>\\n'\n            f'<{DSML_TOKEN}parameter name=\"config\" string=\"false\">'\n            f'{{\"recurring\": true, \"days\": [\"mon\", \"wed\", \"fri\"], \"time\": \"09:00\"}}'\n            f\"</{DSML_TOKEN}parameter>\\n\"\n            f\"</{DSML_TOKEN}invoke>\\n\"\n            
f\"{TOOL_CALLS_END}\"\n        )\n\n        # Feed as single token (model might produce it all at once after prefill)\n        results = list(parse_deepseek_v32(_simulate_tokens([dsml_block])))\n        tool_results = [r for r in results if isinstance(r, ToolCallResponse)]\n\n        assert len(tool_results) == 1\n        tc = tool_results[0].tool_calls[0]\n        assert tc.name == \"create_event\"\n        args = json.loads(tc.arguments)  # pyright: ignore[reportAny]\n        assert args[\"title\"] == \"Team Standup\"\n        assert args[\"config\"][\"recurring\"] is True\n        assert args[\"config\"][\"days\"] == [\"mon\", \"wed\", \"fri\"]\n\n    def test_text_with_angle_brackets_not_mistaken_for_dsml(self):\n        \"\"\"Angle brackets in normal text should not trigger DSML buffering.\"\"\"\n        model_tokens = [\n            \"The formula is \",\n            \"<x, y>\",\n            \" where x > 0\",\n            \" and y < 100.\",\n        ]\n\n        results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))\n        gen_results = [r for r in results if isinstance(r, GenerationResponse)]\n        tool_results = [r for r in results if isinstance(r, ToolCallResponse)]\n\n        assert len(tool_results) == 0\n        full_text = \"\".join(r.text for r in gen_results)\n        assert \"formula\" in full_text\n        assert \"<x, y>\" in full_text\n\n    def test_empty_model_response(self):\n        \"\"\"Model produces only EOS (empty response).\"\"\"\n        model_tokens = [\"\"]\n        results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))\n        gen_results = [r for r in results if isinstance(r, GenerationResponse)]\n        assert len(gen_results) == 1\n        assert gen_results[0].text == \"\"\n        assert gen_results[0].finish_reason == \"stop\"\n\n\n# ── Test: Full EPDP spec round-trip ──────────────────────────────\n\n\nclass TestE2EFullRoundTrip:\n    \"\"\"Full round-trip matching the vLLM EPDP spec.\n\n    Simulates the complete multi-turn flow:\n      Turn 1: user asks → think → tool call → tool result → think → answer\n      Turn 2: user asks again → old reasoning stripped → think → answer\n    \"\"\"\n\n    def test_single_tool_full_flow_with_thinking(self):\n        \"\"\"Complete flow: user → think → tool call → tool result → think → answer.\n\n        This is the core EPDP flow from the vLLM spec.\n        \"\"\"\n        # ── Turn 1.1: User asks, encode prompt ──\n        messages: list[dict[str, Any]] = [\n            {\"role\": \"system\", \"content\": \"You are a weather assistant.\"},\n            {\"role\": \"user\", \"content\": \"How's the weather in Hangzhou?\"},\n        ]\n        prompt_1 = encode_messages(\n            messages, tools=_WEATHER_TOOLS, thinking_mode=\"thinking\"\n        )\n        assert prompt_1.endswith(THINKING_START)\n        assert \"## Tools\" in prompt_1\n        assert \"get_weather\" in prompt_1\n\n        # ── Turn 1.1: Model thinks, then calls tool ──\n        model_tokens_1 = [\n            THINKING_START,\n            \"The user wants to know the weather in Hangzhou.\",\n            \" I need to use the get_weather tool.\",\n            THINKING_END,\n            \"\\n\\n\",\n            TOOL_CALLS_START,\n            \"\\n\",\n            f'<{DSML_TOKEN}invoke name=\"get_weather\">\\n',\n            f'<{DSML_TOKEN}parameter name=\"city\" string=\"true\">Hangzhou</{DSML_TOKEN}parameter>\\n',\n            f\"</{DSML_TOKEN}invoke>\\n\",\n            TOOL_CALLS_END,\n        ]\n        
results_1 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_1)))\n\n        # Verify: thinking tokens + tool call\n        gen_1 = [r for r in results_1 if isinstance(r, GenerationResponse)]\n        tool_1 = [r for r in results_1 if isinstance(r, ToolCallResponse)]\n        thinking_1 = [r for r in gen_1 if r.is_thinking]\n\n        assert len(thinking_1) >= 1\n        assert \"get_weather tool\" in \"\".join(r.text for r in thinking_1)\n        assert len(tool_1) == 1\n        assert tool_1[0].tool_calls[0].name == \"get_weather\"\n        tc_args = json.loads(tool_1[0].tool_calls[0].arguments)  # pyright: ignore[reportAny]\n        assert tc_args == {\"city\": \"Hangzhou\"}\n\n        # ── Turn 1.2: Add assistant response + tool result to messages ──\n        messages.append(\n            {\n                \"role\": \"assistant\",\n                \"content\": \"\",\n                \"reasoning_content\": \"The user wants to know the weather in Hangzhou. I need to use the get_weather tool.\",\n                \"tool_calls\": [\n                    {\n                        \"type\": \"function\",\n                        \"function\": {\n                            \"name\": \"get_weather\",\n                            \"arguments\": '{\"city\": \"Hangzhou\"}',\n                        },\n                    }\n                ],\n            }\n        )\n        messages.append(\n            {\n                \"role\": \"tool\",\n                \"content\": '{\"temperature\": \"7~13°C\", \"condition\": \"Cloudy\"}',\n            }\n        )\n\n        # Encode prompt for turn 1.2\n        prompt_2 = encode_messages(\n            messages, tools=_WEATHER_TOOLS, thinking_mode=\"thinking\"\n        )\n\n        # Verify: prompt has the full conversation structure\n        assert TOOL_CALLS_START in prompt_2  # assistant's encoded tool call\n        assert EOS_TOKEN in prompt_2  # assistant turn ends with EOS\n        assert \"<function_results>\" in prompt_2\n        assert \"<result>\" in prompt_2\n        assert \"Cloudy\" in prompt_2\n        assert \"</function_results>\" in prompt_2\n        # After tool results with thinking enabled → <think> appended\n        assert prompt_2.endswith(THINKING_START)\n        # The assistant's reasoning_content should appear (it's after last_user_idx)\n        assert \"get_weather tool\" in prompt_2\n\n        # ── Turn 1.2: Model thinks about results, then answers ──\n        model_tokens_2 = [\n            THINKING_START,\n            \"The weather in Hangzhou is Cloudy, 7~13°C.\",\n            \" I'll tell the user.\",\n            THINKING_END,\n            \"The weather in Hangzhou is currently cloudy with temperatures between 7°C and 13°C.\",\n        ]\n        results_2 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_2)))\n\n        gen_2 = [r for r in results_2 if isinstance(r, GenerationResponse)]\n        tool_2 = [r for r in results_2 if isinstance(r, ToolCallResponse)]\n        thinking_2 = [r for r in gen_2 if r.is_thinking]\n        non_thinking_2 = [r for r in gen_2 if not r.is_thinking]\n\n        assert len(tool_2) == 0  # No more tool calls\n        assert len(thinking_2) >= 1\n        assert \"Cloudy\" in \"\".join(r.text for r in thinking_2)\n        assert len(non_thinking_2) >= 1\n        final_text = \"\".join(r.text for r in non_thinking_2)\n        assert \"7°C\" in final_text\n        assert \"13°C\" in final_text\n\n    def test_multi_tool_full_flow(self):\n        \"\"\"Flow with two tools: user → 
think → 2 tool calls → 2 results → think → answer.\"\"\"\n        # ── Initial prompt ──\n        messages: list[dict[str, Any]] = [\n            {\"role\": \"system\", \"content\": \"You help with weather and time.\"},\n            {\"role\": \"user\", \"content\": \"Weather in NYC and time in EST?\"},\n        ]\n        prompt_1 = encode_messages(\n            messages, tools=_WEATHER_TOOLS, thinking_mode=\"thinking\"\n        )\n        assert prompt_1.endswith(THINKING_START)\n\n        # ── Model thinks, calls both tools ──\n        model_tokens_1 = [\n            THINKING_START,\n            \"Two requests: weather and time. I'll call both.\",\n            THINKING_END,\n            \"\\n\\n\",\n            TOOL_CALLS_START,\n            \"\\n\",\n            f'<{DSML_TOKEN}invoke name=\"get_weather\">\\n',\n            f'<{DSML_TOKEN}parameter name=\"city\" string=\"true\">NYC</{DSML_TOKEN}parameter>\\n',\n            f\"</{DSML_TOKEN}invoke>\\n\",\n            f'<{DSML_TOKEN}invoke name=\"get_time\">\\n',\n            f'<{DSML_TOKEN}parameter name=\"timezone\" string=\"true\">EST</{DSML_TOKEN}parameter>\\n',\n            f\"</{DSML_TOKEN}invoke>\\n\",\n            TOOL_CALLS_END,\n        ]\n        results_1 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_1)))\n        tool_1 = [r for r in results_1 if isinstance(r, ToolCallResponse)]\n\n        assert len(tool_1) == 1\n        assert len(tool_1[0].tool_calls) == 2\n        assert tool_1[0].tool_calls[0].name == \"get_weather\"\n        assert tool_1[0].tool_calls[1].name == \"get_time\"\n\n        # ── Add assistant + both tool results ──\n        messages.append(\n            {\n                \"role\": \"assistant\",\n                \"content\": \"\",\n                \"reasoning_content\": \"Two requests: weather and time. I'll call both.\",\n                \"tool_calls\": [\n                    {\n                        \"type\": \"function\",\n                        \"function\": {\n                            \"name\": \"get_weather\",\n                            \"arguments\": '{\"city\": \"NYC\"}',\n                        },\n                    },\n                    {\n                        \"type\": \"function\",\n                        \"function\": {\n                            \"name\": \"get_time\",\n                            \"arguments\": '{\"timezone\": \"EST\"}',\n                        },\n                    },\n                ],\n            }\n        )\n        messages.append({\"role\": \"tool\", \"content\": \"72°F, sunny\"})\n        messages.append({\"role\": \"tool\", \"content\": \"2:30 PM EST\"})\n\n        prompt_2 = encode_messages(\n            messages, tools=_WEATHER_TOOLS, thinking_mode=\"thinking\"\n        )\n\n        # Verify multi-tool result encoding\n        # Count is 2: 1 in _TOOLS_SYSTEM_TEMPLATE example + 1 in conversation\n        assert prompt_2.count(\"<function_results>\") == 2\n        assert prompt_2.count(\"</function_results>\") == 2\n        assert \"<result>72°F, sunny</result>\" in prompt_2\n        assert \"<result>2:30 PM EST</result>\" in prompt_2\n        assert prompt_2.endswith(THINKING_START)\n\n        # ── Model thinks about results, answers ──\n        model_tokens_2 = [\n            THINKING_START,\n            \"Got both results. Weather is 72°F sunny, time is 2:30 PM.\",\n            THINKING_END,\n            \"In NYC it's currently 72°F and sunny. 
The time in EST is 2:30 PM.\",\n        ]\n        results_2 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_2)))\n\n        tool_2 = [r for r in results_2 if isinstance(r, ToolCallResponse)]\n        gen_2 = [r for r in results_2 if isinstance(r, GenerationResponse)]\n        non_thinking_2 = [r for r in gen_2 if not r.is_thinking]\n\n        assert len(tool_2) == 0\n        final_text = \"\".join(r.text for r in non_thinking_2)\n        assert \"72°F\" in final_text\n        assert \"2:30 PM\" in final_text\n\n    def test_two_user_turns_reasoning_stripped(self):\n        \"\"\"Turn 2: old reasoning_content is stripped from history.\n\n        Per the vLLM spec, clear_reasoning_content is called between user turns\n        to save bandwidth. Our _drop_old_thinking handles this.\n        \"\"\"\n        # Full turn 1 conversation (already completed)\n        messages: list[dict[str, Any]] = [\n            {\"role\": \"system\", \"content\": \"You are helpful.\"},\n            {\"role\": \"user\", \"content\": \"Weather in Hangzhou?\"},\n            {\n                \"role\": \"assistant\",\n                \"content\": \"\",\n                \"reasoning_content\": \"I need to call get_weather for Hangzhou.\",\n                \"tool_calls\": [\n                    {\n                        \"type\": \"function\",\n                        \"function\": {\n                            \"name\": \"get_weather\",\n                            \"arguments\": '{\"city\": \"Hangzhou\"}',\n                        },\n                    }\n                ],\n            },\n            {\"role\": \"tool\", \"content\": \"Cloudy 7~13°C\"},\n            {\n                \"role\": \"assistant\",\n                \"content\": \"The weather in Hangzhou is cloudy, 7-13°C.\",\n                \"reasoning_content\": \"The tool returned cloudy weather. 
I'll summarize.\",\n            },\n            # Turn 2: user asks again\n            {\"role\": \"user\", \"content\": \"What about Beijing?\"},\n        ]\n\n        prompt = encode_messages(\n            messages, tools=_WEATHER_TOOLS, thinking_mode=\"thinking\"\n        )\n\n        # Old reasoning_content from turn 1 assistants should be STRIPPED\n        # (they're before the last user message at index 5)\n        assert \"I need to call get_weather\" not in prompt\n        assert \"tool returned cloudy\" not in prompt\n\n        # But the assistant's content and tool calls should still be there\n        assert \"cloudy, 7-13°C\" in prompt\n        assert TOOL_CALLS_START in prompt\n\n        # Prompt ends with <think> for the new turn\n        assert prompt.endswith(THINKING_START)\n\n        # ── Turn 2: Model thinks, calls tool for Beijing ──\n        model_tokens = [\n            THINKING_START,\n            \"Now the user wants Beijing weather.\",\n            THINKING_END,\n            \"\\n\\n\",\n            TOOL_CALLS_START,\n            \"\\n\",\n            f'<{DSML_TOKEN}invoke name=\"get_weather\">\\n',\n            f'<{DSML_TOKEN}parameter name=\"city\" string=\"true\">Beijing</{DSML_TOKEN}parameter>\\n',\n            f\"</{DSML_TOKEN}invoke>\\n\",\n            TOOL_CALLS_END,\n        ]\n        results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))\n        tool_results = [r for r in results if isinstance(r, ToolCallResponse)]\n\n        assert len(tool_results) == 1\n        assert tool_results[0].tool_calls[0].name == \"get_weather\"\n        args = json.loads(tool_results[0].tool_calls[0].arguments)  # pyright: ignore[reportAny]\n        assert args == {\"city\": \"Beijing\"}\n\n    def test_chained_tool_calls_loop(self):\n        \"\"\"Model calls tool, gets result, calls another tool, gets result, answers.\n\n        This simulates the inner while loop from the vLLM spec where the model\n        may need multiple sub-turns of tool calling before it has enough info.\n        \"\"\"\n        # ── Sub-turn 1: user asks, model calls get_time ──\n        messages: list[dict[str, Any]] = [\n            {\"role\": \"system\", \"content\": \"You are helpful.\"},\n            {\"role\": \"user\", \"content\": \"What's the weather in Hangzhou tomorrow?\"},\n        ]\n\n        prompt_1 = encode_messages(\n            messages, tools=_WEATHER_TOOLS, thinking_mode=\"thinking\"\n        )\n        assert prompt_1.endswith(THINKING_START)\n\n        # Model first calls get_time to figure out the date\n        model_tokens_1 = [\n            THINKING_START,\n            \"I need the current date first to calculate tomorrow.\",\n            THINKING_END,\n            \"\\n\\n\",\n            TOOL_CALLS_START,\n            \"\\n\",\n            f'<{DSML_TOKEN}invoke name=\"get_time\">\\n',\n            f'<{DSML_TOKEN}parameter name=\"timezone\" string=\"true\">Asia/Shanghai</{DSML_TOKEN}parameter>\\n',\n            f\"</{DSML_TOKEN}invoke>\\n\",\n            TOOL_CALLS_END,\n        ]\n        results_1 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_1)))\n        tool_1 = [r for r in results_1 if isinstance(r, ToolCallResponse)]\n        assert len(tool_1) == 1\n        assert tool_1[0].tool_calls[0].name == \"get_time\"\n\n        # ── Sub-turn 2: add tool result, model calls get_weather ──\n        messages.append(\n            {\n                \"role\": \"assistant\",\n                \"content\": \"\",\n                \"reasoning_content\": \"I need 
the current date first to calculate tomorrow.\",\n                \"tool_calls\": [\n                    {\n                        \"type\": \"function\",\n                        \"function\": {\n                            \"name\": \"get_time\",\n                            \"arguments\": '{\"timezone\": \"Asia/Shanghai\"}',\n                        },\n                    }\n                ],\n            }\n        )\n        messages.append({\"role\": \"tool\", \"content\": \"2025-12-01 14:30 CST\"})\n\n        prompt_2 = encode_messages(\n            messages, tools=_WEATHER_TOOLS, thinking_mode=\"thinking\"\n        )\n        assert \"<result>2025-12-01 14:30 CST</result>\" in prompt_2\n        assert prompt_2.endswith(THINKING_START)\n\n        # Model now knows the date, calls get_weather\n        model_tokens_2 = [\n            THINKING_START,\n            \"Today is 2025-12-01, so tomorrow is 2025-12-02.\",\n            \" Now I can check weather for Hangzhou.\",\n            THINKING_END,\n            \"\\n\\n\",\n            TOOL_CALLS_START,\n            \"\\n\",\n            f'<{DSML_TOKEN}invoke name=\"get_weather\">\\n',\n            f'<{DSML_TOKEN}parameter name=\"city\" string=\"true\">Hangzhou</{DSML_TOKEN}parameter>\\n',\n            f\"</{DSML_TOKEN}invoke>\\n\",\n            TOOL_CALLS_END,\n        ]\n        results_2 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_2)))\n        tool_2 = [r for r in results_2 if isinstance(r, ToolCallResponse)]\n        assert len(tool_2) == 1\n        assert tool_2[0].tool_calls[0].name == \"get_weather\"\n\n        # ── Sub-turn 3: add weather result, model answers ──\n        messages.append(\n            {\n                \"role\": \"assistant\",\n                \"content\": \"\",\n                \"reasoning_content\": \"Today is 2025-12-01, so tomorrow is 2025-12-02. 
Now I can check weather for Hangzhou.\",\n                \"tool_calls\": [\n                    {\n                        \"type\": \"function\",\n                        \"function\": {\n                            \"name\": \"get_weather\",\n                            \"arguments\": '{\"city\": \"Hangzhou\"}',\n                        },\n                    }\n                ],\n            }\n        )\n        messages.append({\"role\": \"tool\", \"content\": \"Sunny, 5~12°C\"})\n\n        prompt_3 = encode_messages(\n            messages, tools=_WEATHER_TOOLS, thinking_mode=\"thinking\"\n        )\n        # Should have both function_results blocks (one per tool round)\n        # Count is 3: 1 in _TOOLS_SYSTEM_TEMPLATE example + 2 in conversation\n        assert prompt_3.count(\"<function_results>\") == 3\n        assert prompt_3.count(\"</function_results>\") == 3\n        assert \"<result>2025-12-01 14:30 CST</result>\" in prompt_3\n        assert \"<result>Sunny, 5~12°C</result>\" in prompt_3\n        assert prompt_3.endswith(THINKING_START)\n\n        # Model finally answers\n        model_tokens_3 = [\n            THINKING_START,\n            \"I have the weather for tomorrow in Hangzhou.\",\n            THINKING_END,\n            \"Tomorrow in Hangzhou will be sunny with temperatures between 5°C and 12°C.\",\n        ]\n        results_3 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_3)))\n\n        tool_3 = [r for r in results_3 if isinstance(r, ToolCallResponse)]\n        gen_3 = [r for r in results_3 if isinstance(r, GenerationResponse)]\n        non_thinking_3 = [r for r in gen_3 if not r.is_thinking]\n\n        assert len(tool_3) == 0  # No more tool calls — loop ends\n        final_text = \"\".join(r.text for r in non_thinking_3)\n        assert \"sunny\" in final_text.lower()\n        assert \"5°C\" in final_text\n        assert \"12°C\" in final_text\n"
  },
  {
    "path": "src/exo/worker/tests/unittests/test_runner/test_event_ordering.py",
    "content": "# Check tasks are complete before runner is ever ready.\nimport unittest.mock\nfrom collections.abc import Iterable\nfrom typing import Callable\n\nimport mlx.core as mx\nimport pytest\n\nimport exo.worker.runner.llm_inference.batch_generator as mlx_batch_generator\nimport exo.worker.runner.llm_inference.model_output_parsers as mlx_model_output_parsers\nimport exo.worker.runner.llm_inference.runner as mlx_runner\nfrom exo.shared.types.chunks import TokenChunk\nfrom exo.shared.types.events import (\n    ChunkGenerated,\n    Event,\n    RunnerStatusUpdated,\n    TaskAcknowledged,\n    TaskStatusUpdated,\n)\nfrom exo.shared.types.tasks import (\n    ConnectToGroup,\n    LoadModel,\n    Shutdown,\n    StartWarmup,\n    Task,\n    TaskId,\n    TaskStatus,\n    TextGeneration,\n)\nfrom exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams\nfrom exo.shared.types.worker.runner_response import GenerationResponse\nfrom exo.shared.types.worker.runners import (\n    RunnerConnected,\n    RunnerConnecting,\n    RunnerIdle,\n    RunnerLoaded,\n    RunnerLoading,\n    RunnerReady,\n    RunnerRunning,\n    RunnerShutdown,\n    RunnerShuttingDown,\n    RunnerWarmingUp,\n)\nfrom exo.utils.channels import mp_channel\n\nfrom ...constants import (\n    CHAT_COMPLETION_TASK_ID,\n    COMMAND_1_ID,\n    INITIALIZATION_TASK_ID,\n    INSTANCE_1_ID,\n    LOAD_TASK_ID,\n    MODEL_A_ID,\n    NODE_A,\n    RUNNER_1_ID,\n    SHUTDOWN_TASK_ID,\n    WARMUP_TASK_ID,\n)\nfrom ..conftest import get_bound_mlx_ring_instance\n\n\ndef make_nothin[T, U, V](res: T) -> Callable[[], T]:\n    def nothin(*_1: U, **_2: V) -> T:\n        return res\n\n    return nothin\n\n\nnothin = make_nothin(None)\n\n\nINIT_TASK = ConnectToGroup(\n    task_id=INITIALIZATION_TASK_ID,\n    instance_id=INSTANCE_1_ID,\n)\n\nLOAD_TASK = LoadModel(\n    task_id=LOAD_TASK_ID,\n    instance_id=INSTANCE_1_ID,\n)\n\nWARMUP_TASK = StartWarmup(\n    task_id=WARMUP_TASK_ID,\n    instance_id=INSTANCE_1_ID,\n)\n\nSHUTDOWN_TASK = Shutdown(\n    task_id=SHUTDOWN_TASK_ID,\n    instance_id=INSTANCE_1_ID,\n    runner_id=RUNNER_1_ID,\n)\n\nCHAT_PARAMS = TextGenerationTaskParams(\n    model=MODEL_A_ID,\n    input=[InputMessage(role=\"user\", content=\"hello\")],\n    stream=True,\n    max_output_tokens=4,\n    temperature=0.0,\n)\n\nCHAT_TASK = TextGeneration(\n    task_id=CHAT_COMPLETION_TASK_ID,\n    command_id=COMMAND_1_ID,\n    task_params=CHAT_PARAMS,\n    instance_id=INSTANCE_1_ID,\n)\n\n\ndef assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Event]):\n    for test_event, true_event in zip(test_events, true_events, strict=True):\n        test_event.event_id = true_event.event_id\n        assert test_event == true_event, f\"{test_event} != {true_event}\"\n\n\n@pytest.fixture\ndef patch_out_mlx(monkeypatch: pytest.MonkeyPatch):\n    # initialize_mlx returns a mock group\n    monkeypatch.setattr(mlx_runner, \"initialize_mlx\", make_nothin(MockGroup()))\n    monkeypatch.setattr(mlx_runner, \"load_mlx_items\", make_nothin((1, MockTokenizer)))\n    monkeypatch.setattr(mlx_batch_generator, \"warmup_inference\", make_nothin(1))\n    monkeypatch.setattr(mlx_batch_generator, \"_check_for_debug_prompts\", nothin)\n    monkeypatch.setattr(mlx_batch_generator, \"mx_any\", make_nothin(False))\n\n    def fake_all_gather(\n        tasks: list[TextGeneration], group: object\n    ) -> tuple[list[TextGeneration], list[TextGeneration]]:\n        return (tasks, [])\n\n    monkeypatch.setattr(mlx_batch_generator, 
\"mx_all_gather_tasks\", fake_all_gather)\n    # Mock apply_chat_template since we're using a fake tokenizer (integer 1).\n    # Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.\n    monkeypatch.setattr(\n        mlx_batch_generator, \"apply_chat_template\", make_nothin(\"test prompt\")\n    )\n    monkeypatch.setattr(\n        mlx_model_output_parsers, \"detect_thinking_prompt_suffix\", make_nothin(False)\n    )\n    monkeypatch.setattr(mlx_batch_generator, \"ExoBatchGenerator\", FakeExoBatchGenerator)\n\n\nclass FakeExoBatchGenerator:\n    def __init__(self, *_args: object, **_kwargs: object) -> None:\n        self._uid_counter = 0\n        self._pending: dict[int, GenerationResponse] = {}\n\n    @property\n    def has_work(self) -> bool:\n        return bool(self._pending)\n\n    def submit(\n        self,\n        task_params: object = None,\n        prompt: object = None,\n        on_prefill_progress: object = None,\n        distributed_prompt_progress_callback: object = None,\n        on_generation_token: object = None,\n    ) -> int:\n        uid = self._uid_counter\n        self._uid_counter += 1\n        self._pending[uid] = GenerationResponse(\n            text=\"hi\",\n            token=0,\n            finish_reason=\"stop\",\n            usage=None,\n        )\n        return uid\n\n    def step(self) -> list[tuple[int, GenerationResponse]]:\n        results = list(self._pending.items())\n        self._pending.clear()\n        return results\n\n    def cancel(self, uids: list[int]) -> None:\n        for uid in uids:\n            self._pending.pop(uid, None)\n\n    def close(self) -> None:\n        pass\n\n\n# Use a fake event_sender to remove test flakiness.\nclass EventCollector:\n    def __init__(self, on_event: Callable[[Event], None] | None = None) -> None:\n        self.events: list[Event] = []\n        self._on_event = on_event\n\n    def send(self, event: Event) -> None:\n        self.events.append(event)\n        if self._on_event:\n            self._on_event(event)\n\n    def close(self) -> None:\n        pass\n\n    def join(self) -> None:\n        pass\n\n\nclass MockTokenizer:\n    tool_parser = None\n    tool_call_start = None\n    tool_call_end = None\n    has_tool_calling = False\n    has_thinking = False\n    think_start = None\n    think_end = None\n    eos_token_ids: list[int] = []\n\n    @staticmethod\n    def decode(_tokens: list[int]) -> str:\n        return \"hi\"\n\n    @staticmethod\n    def encode(_text: str, add_special_tokens: bool = True) -> list[int]:\n        return [0]\n\n\nclass MockGroup:\n    def rank(self) -> int:\n        return 0\n\n    def size(self) -> int:\n        return 1\n\n\ndef _run(tasks: Iterable[Task], send_after_ready: list[Task] | None = None):\n    bound_instance = get_bound_mlx_ring_instance(\n        instance_id=INSTANCE_1_ID,\n        model_id=MODEL_A_ID,\n        runner_id=RUNNER_1_ID,\n        node_id=NODE_A,\n    )\n\n    task_sender, task_receiver = mp_channel[Task]()\n    _cancel_sender, cancel_receiver = mp_channel[TaskId]()\n\n    on_event: Callable[[Event], None] | None = None\n    if send_after_ready:\n        _saw_running = False\n\n        def _on_event(event: Event) -> None:\n            nonlocal _saw_running\n            if isinstance(event, RunnerStatusUpdated):\n                if isinstance(event.runner_status, RunnerRunning):\n                    _saw_running = True\n                elif _saw_running and isinstance(event.runner_status, RunnerReady):\n                    
for t in send_after_ready:\n                        task_sender.send(t)\n\n        on_event = _on_event\n\n    event_sender = EventCollector(on_event=on_event)\n\n    with task_sender:\n        for t in tasks:\n            task_sender.send(t)\n\n        # worst monkeypatch known to man\n        # this is some c++ nonsense\n        task_receiver.close = nothin\n        task_receiver.join = nothin\n        with unittest.mock.patch(\n            \"exo.worker.runner.llm_inference.runner.mx.distributed.all_gather\",\n            make_nothin(mx.array([1])),\n        ):\n            runner = mlx_runner.Runner(\n                bound_instance,\n                event_sender,  # pyright: ignore[reportArgumentType]\n                task_receiver,\n                cancel_receiver,\n            )\n            runner.main()\n\n        return event_sender.events\n\n\ndef test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):\n    events = _run(\n        [INIT_TASK, LOAD_TASK, WARMUP_TASK, CHAT_TASK],\n        send_after_ready=[SHUTDOWN_TASK],\n    )\n\n    expected_chunk = ChunkGenerated(\n        command_id=COMMAND_1_ID,\n        chunk=TokenChunk(\n            model=MODEL_A_ID,\n            text=\"hi\",\n            token_id=0,\n            finish_reason=\"stop\",\n            usage=None,\n            stats=None,\n        ),\n    )\n\n    assert_events_equal(\n        events,\n        [\n            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerIdle()),\n            TaskStatusUpdated(\n                task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Running\n            ),\n            RunnerStatusUpdated(\n                runner_id=RUNNER_1_ID, runner_status=RunnerConnecting()\n            ),\n            TaskAcknowledged(task_id=INITIALIZATION_TASK_ID),\n            TaskStatusUpdated(\n                task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Complete\n            ),\n            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerConnected()),\n            TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Running),\n            RunnerStatusUpdated(\n                runner_id=RUNNER_1_ID,\n                runner_status=RunnerLoading(layers_loaded=0, total_layers=32),\n            ),\n            TaskAcknowledged(task_id=LOAD_TASK_ID),\n            TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Complete),\n            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoaded()),\n            TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Running),\n            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerWarmingUp()),\n            TaskAcknowledged(task_id=WARMUP_TASK_ID),\n            TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Complete),\n            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),\n            TaskStatusUpdated(\n                task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Running\n            ),\n            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerRunning()),\n            TaskAcknowledged(task_id=CHAT_COMPLETION_TASK_ID),\n            expected_chunk,\n            TaskStatusUpdated(\n                task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Complete\n            ),\n            # CHAT COMPLETION TASK SHOULD COMPLETE BEFORE RUNNER READY\n            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),\n            
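# SHUTDOWN_TASK was only sent once _on_event observed this Ready status,\n            # so all of its events must appear after it.\n            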
TaskStatusUpdated(task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Running),\n            RunnerStatusUpdated(\n                runner_id=RUNNER_1_ID, runner_status=RunnerShuttingDown()\n            ),\n            TaskAcknowledged(task_id=SHUTDOWN_TASK_ID),\n            TaskStatusUpdated(\n                task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Complete\n            ),\n            # SPECIAL EXCEPTION FOR RUNNER SHUTDOWN\n            RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerShutdown()),\n        ],\n    )\n"
  },
  {
    "path": "src/exo/worker/tests/unittests/test_runner/test_glm_tool_parsing.py",
    "content": "\"\"\"Tests for GLM tool call argument parsing regex.\"\"\"\n\nimport regex as re\n\n# Replicate the regex patterns from runner.py to test them in isolation\n_func_name_regex = re.compile(r\"^(.*?)<arg_key>\", re.DOTALL)\n_func_arg_regex = re.compile(\n    r\"<arg_key>(.*?)</arg_key>(?:\\n|\\s)*<arg_value>(.*?)(?:</arg_value>|(?=<arg_key>)|$)\",\n    re.DOTALL,\n)\n\n\ndef _parse_args(text: str) -> list[tuple[str, str]]:\n    \"\"\"Extract (key, value) pairs from GLM tool call text.\"\"\"\n    pairs = _func_arg_regex.findall(text)\n    return [(k.strip(), v.strip()) for k, v in pairs]  # pyright: ignore[reportAny]\n\n\ndef _parse_func_name(text: str) -> str:\n    \"\"\"Extract function name from GLM tool call text.\"\"\"\n    match = _func_name_regex.search(text)\n    if match is None:\n        raise ValueError(f\"Could not parse function name: {text!r}\")\n    return match.group(1).strip()\n\n\nclass TestGlmToolParsingWithClosingTags:\n    \"\"\"Tests for normal format with closing tags present.\"\"\"\n\n    def test_single_argument(self):\n        text = (\n            \"get_weather<arg_key>location</arg_key><arg_value>San Francisco</arg_value>\"\n        )\n        assert _parse_func_name(text) == \"get_weather\"\n        pairs = _parse_args(text)\n        assert pairs == [(\"location\", \"San Francisco\")]\n\n    def test_multiple_arguments(self):\n        text = (\n            \"search<arg_key>query</arg_key><arg_value>python</arg_value>\"\n            \"<arg_key>limit</arg_key><arg_value>10</arg_value>\"\n        )\n        assert _parse_func_name(text) == \"search\"\n        pairs = _parse_args(text)\n        assert pairs == [(\"query\", \"python\"), (\"limit\", \"10\")]\n\n    def test_arguments_with_whitespace_between(self):\n        text = (\n            \"fn<arg_key>a</arg_key>\\n<arg_value>1</arg_value>\\n\"\n            \"<arg_key>b</arg_key> <arg_value>2</arg_value>\"\n        )\n        pairs = _parse_args(text)\n        assert pairs == [(\"a\", \"1\"), (\"b\", \"2\")]\n\n\nclass TestGlmToolParsingMissingClosingTags:\n    \"\"\"Tests for format where </arg_value> closing tags are missing.\"\"\"\n\n    def test_single_argument_no_closing(self):\n        text = \"get_weather<arg_key>location</arg_key><arg_value>San Francisco\"\n        assert _parse_func_name(text) == \"get_weather\"\n        pairs = _parse_args(text)\n        assert pairs == [(\"location\", \"San Francisco\")]\n\n    def test_multiple_arguments_no_closing(self):\n        text = (\n            \"search<arg_key>query</arg_key><arg_value>python\"\n            \"<arg_key>limit</arg_key><arg_value>10\"\n        )\n        assert _parse_func_name(text) == \"search\"\n        pairs = _parse_args(text)\n        assert pairs == [(\"query\", \"python\"), (\"limit\", \"10\")]\n\n    def test_mixed_closing_tags(self):\n        \"\"\"First arg has closing tag, second does not.\"\"\"\n        text = (\n            \"fn<arg_key>a</arg_key><arg_value>1</arg_value>\"\n            \"<arg_key>b</arg_key><arg_value>2\"\n        )\n        pairs = _parse_args(text)\n        assert pairs == [(\"a\", \"1\"), (\"b\", \"2\")]\n\n    def test_value_with_trailing_whitespace(self):\n        text = \"fn<arg_key>x</arg_key><arg_value>hello world  \\n\"\n        pairs = _parse_args(text)\n        assert pairs == [(\"x\", \"hello world\")]\n\n    def test_value_with_newlines_no_closing(self):\n        text = \"fn<arg_key>data</arg_key><arg_value>line1\\nline2\"\n        pairs = _parse_args(text)\n        assert pairs == 
[(\"data\", \"line1\\nline2\")]\n\n\nclass TestGlmToolParsingEdgeCases:\n    \"\"\"Edge case tests for GLM tool call parsing.\"\"\"\n\n    def test_empty_value_with_closing(self):\n        text = \"fn<arg_key>empty</arg_key><arg_value></arg_value>\"\n        pairs = _parse_args(text)\n        assert pairs == [(\"empty\", \"\")]\n\n    def test_value_with_json_content(self):\n        text = 'fn<arg_key>data</arg_key><arg_value>{\"key\": \"value\"}</arg_value>'\n        pairs = _parse_args(text)\n        assert pairs == [(\"data\", '{\"key\": \"value\"}')]\n\n    def test_value_with_json_no_closing(self):\n        text = 'fn<arg_key>data</arg_key><arg_value>{\"key\": \"value\"}'\n        pairs = _parse_args(text)\n        assert pairs == [(\"data\", '{\"key\": \"value\"}')]\n"
  },
  {
    "path": "src/exo/worker/tests/unittests/test_runner/test_parse_gpt_oss.py",
    "content": "from collections.abc import Generator\n\nfrom exo.api.types import FinishReason\nfrom exo.shared.types.worker.runner_response import (\n    GenerationResponse,\n    ToolCallResponse,\n)\nfrom exo.worker.runner.llm_inference.model_output_parsers import parse_gpt_oss\n\n# Token IDs from mlx-community/gpt-oss-20b-MXFP4-Q8 tokenizer.\n# These are stable since they come from the model's vocabulary.\n_CHANNEL = 200005  # <|channel|>\n_START = 200006  # <|start|>\n_MESSAGE = 200008  # <|message|>\n_CALL = 200012  # <|call|>\n_END = 200007  # <|end|>\n_ASSISTANT = 173781  # \"assistant\"\n\n# fmt: off\n# \" to=functions.get_current_weather<|channel|>commentary json<|message|>{\\\"location\\\": \\\"Tokyo\\\"}<|call|>\"\nFORMAT_A_TOKENS: list[tuple[int, str]] = [\n    (316,    \" to\"),\n    (28,     \"=\"),\n    (44580,  \"functions\"),\n    (775,    \".get\"),\n    (23981,  \"_current\"),\n    (170154, \"_weather\"),\n    (_CHANNEL, \"<|channel|>\"),\n    (12606,  \"comment\"),\n    (815,    \"ary\"),\n    (5701,   \" json\"),\n    (_MESSAGE, \"<|message|>\"),\n    (10848,  '{\"'),\n    (7693,   \"location\"),\n    (1243,   '\":'),\n    (392,    ' \"'),\n    (173844, \"Tokyo\"),\n    (18583,  '\"}'),\n    (_CALL,  \"<|call|>\"),\n]\n\n# \"<|channel|>commentary to=functions.get_current_weather json<|message|>{\\\"location\\\": \\\"Tokyo\\\"}<|call|>\"\nFORMAT_B_TOKENS: list[tuple[int, str]] = [\n    (_CHANNEL, \"<|channel|>\"),\n    (12606,  \"comment\"),\n    (815,    \"ary\"),\n    (316,    \" to\"),\n    (28,     \"=\"),\n    (44580,  \"functions\"),\n    (775,    \".get\"),\n    (23981,  \"_current\"),\n    (170154, \"_weather\"),\n    (5701,   \" json\"),\n    (_MESSAGE, \"<|message|>\"),\n    (10848,  '{\"'),\n    (7693,   \"location\"),\n    (1243,   '\":'),\n    (392,    ' \"'),\n    (173844, \"Tokyo\"),\n    (18583,  '\"}'),\n    (_CALL,  \"<|call|>\"),\n]\n\n# \"<|channel|>analysis<|message|>Let me think...<|end|><|start|>assistant<|channel|>commentary to=functions.X ...\"\n# Full analysis-then-tool-call as the model actually generates it.\nTHINKING_THEN_TOOL_TOKENS: list[tuple[int, str]] = [\n    (_CHANNEL, \"<|channel|>\"),\n    (35644,  \"analysis\"),\n    (_MESSAGE, \"<|message|>\"),\n    (12845,  \"Let\"),\n    (668,    \" me\"),\n    (2411,   \" think\"),\n    (1078,   \" about\"),\n    (495,    \" this\"),\n    (13,     \".\"),\n    (_END,   \"<|end|>\"),\n    # Model generates a new message header for the tool call:\n    (_START, \"<|start|>\"),\n    (_ASSISTANT, \"assistant\"),\n    *FORMAT_B_TOKENS,\n]\n# fmt: on\n\n\ndef _make_gen_responses(\n    tokens: list[tuple[int, str]],\n    last_finish_reason: FinishReason = \"stop\",\n) -> list[GenerationResponse]:\n    \"\"\"Build GenerationResponse list from (token_id, text) pairs.\"\"\"\n    responses: list[GenerationResponse] = []\n    for i, (tid, text) in enumerate(tokens):\n        is_last = i == len(tokens) - 1\n        responses.append(\n            GenerationResponse(\n                text=text,\n                token=tid,\n                finish_reason=last_finish_reason if is_last else None,\n                usage=None,\n            )\n        )\n    return responses\n\n\ndef _collect(\n    tokens: list[tuple[int, str]],\n    last_finish_reason: FinishReason = \"stop\",\n) -> list[GenerationResponse | ToolCallResponse]:\n    \"\"\"Feed tokens through parse_gpt_oss and collect all yielded responses.\"\"\"\n\n    def _gen() -> Generator[GenerationResponse, None, None]:\n        yield from 
_make_gen_responses(tokens, last_finish_reason)\n\n    return list(x for x in parse_gpt_oss(_gen()) if x is not None)\n\n\ndef _get_tool_call(\n    results: list[GenerationResponse | ToolCallResponse],\n) -> ToolCallResponse:\n    \"\"\"Extract the single ToolCallResponse from results.\"\"\"\n    tool_calls = [r for r in results if isinstance(r, ToolCallResponse)]\n    assert len(tool_calls) == 1, f\"Expected 1 ToolCallResponse, got {len(tool_calls)}\"\n    return tool_calls[0]\n\n\nclass TestParseGptOssRecipientPlacement:\n    \"\"\"Both Harmony recipient placements must produce identical tool calls.\"\"\"\n\n    def test_format_a_yields_tool_call(self):\n        results = _collect(FORMAT_A_TOKENS)\n        tc = _get_tool_call(results)\n        assert tc.tool_calls[0].name == \"get_current_weather\"\n        assert '\"location\"' in tc.tool_calls[0].arguments\n        assert \"Tokyo\" in tc.tool_calls[0].arguments\n\n    def test_format_b_yields_tool_call(self):\n        results = _collect(FORMAT_B_TOKENS)\n        tc = _get_tool_call(results)\n        assert tc.tool_calls[0].name == \"get_current_weather\"\n        assert '\"location\"' in tc.tool_calls[0].arguments\n        assert \"Tokyo\" in tc.tool_calls[0].arguments\n\n    def test_both_formats_produce_identical_tool_calls(self):\n        tc_a = _get_tool_call(_collect(FORMAT_A_TOKENS))\n        tc_b = _get_tool_call(_collect(FORMAT_B_TOKENS))\n        assert tc_a.tool_calls[0].name == tc_b.tool_calls[0].name\n        assert tc_a.tool_calls[0].arguments == tc_b.tool_calls[0].arguments\n\n\nclass TestParseGptOssThinkingThenToolCall:\n    \"\"\"Analysis (thinking) followed by a tool call must yield both.\"\"\"\n\n    def test_thinking_then_tool_call(self):\n        results = _collect(THINKING_THEN_TOOL_TOKENS)\n\n        # Thinking tokens should have is_thinking=True and no <think> tags\n        thinking_responses = [\n            r for r in results if isinstance(r, GenerationResponse) and r.is_thinking\n        ]\n        thinking_text = \"\".join(r.text for r in thinking_responses)\n        assert \"Let me think about this.\" in thinking_text\n        assert \"<think>\" not in thinking_text\n        assert \"</think>\" not in thinking_text\n\n        # Non-thinking tokens should have is_thinking=False\n        non_thinking = [\n            r\n            for r in results\n            if isinstance(r, GenerationResponse) and not r.is_thinking\n        ]\n        non_thinking_text = \"\".join(r.text for r in non_thinking)\n        assert \"<think>\" not in non_thinking_text\n\n        # And the tool call\n        tc = _get_tool_call(results)\n        assert tc.tool_calls[0].name == \"get_current_weather\"\n        assert \"Tokyo\" in tc.tool_calls[0].arguments\n\n\n# fmt: off\n# Truncated tool call: recipient + channel + message + partial args, no <|call|>\nTRUNCATED_TOOL_CALL_TOKENS: list[tuple[int, str]] = [\n    (316,    \" to\"),\n    (28,     \"=\"),\n    (44580,  \"functions\"),\n    (775,    \".get\"),\n    (23981,  \"_current\"),\n    (170154, \"_weather\"),\n    (_CHANNEL, \"<|channel|>\"),\n    (12606,  \"comment\"),\n    (815,    \"ary\"),\n    (5701,   \" json\"),\n    (_MESSAGE, \"<|message|>\"),\n    (10848,  '{\"'),\n    (7693,   \"location\"),\n    (1243,   '\":'),\n    (392,    ' \"'),\n    (173844, \"Tokyo\"),\n    # No <|call|> — generation truncated here\n]\n\n# Plain text tokens (no tool call)\nPLAIN_TEXT_TOKENS: list[tuple[int, str]] = [\n    (_CHANNEL, \"<|channel|>\"),\n    (35644,  \"analysis\"),\n    
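# analysis message: hidden reasoning, closed by <|end|> before the visible commentary text\n    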
(_MESSAGE, \"<|message|>\"),\n    (12845,  \"Let\"),\n    (668,    \" me\"),\n    (2411,   \" think\"),\n    (1078,   \" about\"),\n    (495,    \" this\"),\n    (13,     \".\"),\n    (_END,   \"<|end|>\"),\n    (_START, \"<|start|>\"),\n    (_ASSISTANT, \"assistant\"),\n    (_CHANNEL, \"<|channel|>\"),\n    (12606,  \"comment\"),\n    (815,    \"ary\"),\n    (_MESSAGE, \"<|message|>\"),\n    (9906,   \"Hello\"),\n    (14,     \",\"),\n    (2989,   \" world\"),\n]\n# fmt: on\n\n\nclass TestParseGptOssMaxTokensTruncation:\n    \"\"\"Truncated tool calls must still yield finish_reason.\"\"\"\n\n    def test_truncated_tool_call_yields_finish_reason(self):\n        results = _collect(TRUNCATED_TOOL_CALL_TOKENS, last_finish_reason=\"length\")\n        gen_responses = [r for r in results if isinstance(r, GenerationResponse)]\n        finish_reasons = [\n            r.finish_reason for r in gen_responses if r.finish_reason is not None\n        ]\n        assert \"length\" in finish_reasons\n\n    def test_truncated_tool_call_emits_partial_args(self):\n        results = _collect(TRUNCATED_TOOL_CALL_TOKENS, last_finish_reason=\"length\")\n        gen_responses = [r for r in results if isinstance(r, GenerationResponse)]\n        last = [r for r in gen_responses if r.finish_reason is not None][-1]\n        assert len(last.text) > 0\n\n    def test_truncated_plain_text_still_works(self):\n        results = _collect(PLAIN_TEXT_TOKENS, last_finish_reason=\"length\")\n        gen_responses = [r for r in results if isinstance(r, GenerationResponse)]\n        finish_reasons = [\n            r.finish_reason for r in gen_responses if r.finish_reason is not None\n        ]\n        assert \"length\" in finish_reasons\n        # Verify non-empty text was yielded (delta text differs from raw token text\n        # due to Harmony encoding, so we just check something was emitted)\n        all_text = \"\".join(r.text for r in gen_responses)\n        assert len(all_text) > 0\n"
  },
  {
    "path": "src/exo/worker/tests/unittests/test_runner/test_parse_tool_calls.py",
    "content": "\"\"\"Tests for parse_tool_calls generator, especially unclosed tool call handling.\"\"\"\n\nimport json\nfrom collections.abc import Generator\nfrom typing import Any\n\nfrom exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse\nfrom exo.worker.runner.llm_inference.model_output_parsers import parse_tool_calls\nfrom exo.worker.runner.llm_inference.tool_parsers import make_mlx_parser\n\n\ndef _make_responses(\n    texts: list[str],\n    finish_on_last: bool = True,\n) -> Generator[GenerationResponse]:\n    \"\"\"Create a sequence of GenerationResponses from text strings.\"\"\"\n    for i, text in enumerate(texts):\n        is_last = i == len(texts) - 1\n        yield GenerationResponse(\n            text=text,\n            token=i,\n            finish_reason=\"stop\" if (is_last and finish_on_last) else None,\n            usage=None,\n        )\n\n\ndef _dummier_parser(text: str) -> dict[str, Any]:\n    return {\"name\": \"test_fn\", \"arguments\": {\"arg\": text}}\n\n\n_dummy_parser = make_mlx_parser(\"<tool_call>\", \"</tool_call>\", _dummier_parser)\n\n\nclass TestParseToolCalls:\n    \"\"\"Tests for parse_tool_calls generator.\"\"\"\n\n    def test_closed_tool_call_works_normally(self):\n        \"\"\"Normal tool call flow should not be affected.\"\"\"\n        texts = [\"<tool_call>\", \"test_fn\", \"</tool_call>\"]\n        results = list(\n            parse_tool_calls(\n                _make_responses(texts, finish_on_last=False),\n                _dummy_parser,\n                tools=None,\n            )\n        )\n\n        assert len(results) == 1\n        assert isinstance(results[0], ToolCallResponse)\n\n    def test_no_tool_call_passes_through(self):\n        \"\"\"Responses without tool calls should pass through unchanged.\"\"\"\n        texts = [\"Hello\", \" world\"]\n        results = list(\n            parse_tool_calls(\n                _make_responses(texts),\n                _dummy_parser,\n                tools=None,\n            )\n        )\n\n        assert len(results) == 2\n        assert all(isinstance(r, GenerationResponse) for r in results)\n        r0 = results[0]\n        r1 = results[1]\n        assert isinstance(r0, GenerationResponse)\n        assert isinstance(r1, GenerationResponse)\n        assert r0.text == \"Hello\"\n        assert r1.text == \" world\"\n        assert r1.finish_reason == \"stop\"\n\n    def test_failed_parse_yields_text(self):\n        \"\"\"When tool call parsing fails, the text should be yielded as-is.\"\"\"\n\n        def _failing_parser(text: str) -> dict[str, Any]:\n            raise ValueError(\"parse failed\")\n\n        texts = [\"<tool_call>\", \"bad content\", \"</tool_call>\"]\n        results = list(\n            parse_tool_calls(\n                _make_responses(texts, finish_on_last=False),\n                make_mlx_parser(\"<tool_call>\", \"</tool_call>\", _failing_parser),\n                tools=None,\n            )\n        )\n\n        assert len(results) == 1\n        assert isinstance(results[0], GenerationResponse)\n        assert results[0].text == \"<tool_call>bad content</tool_call>\"\n\n    def test_tool_schema_coerces_string_arguments_to_expected_types(self):\n        \"\"\"Tool argument values should be coerced using provided JSON schema.\"\"\"\n\n        def _parser_with_string_args(_text: str) -> dict[str, Any]:\n            return {\n                \"name\": \"process\",\n                \"arguments\": {\n                    \"action\": \"output\",\n        
            \"id\": \"0\",\n                    \"verbose\": \"true\",\n                    \"temperature\": \"0.75\",\n                },\n            }\n\n        tools = [\n            {\n                \"type\": \"function\",\n                \"function\": {\n                    \"name\": \"process\",\n                    \"description\": \"Manage background processes\",\n                    \"parameters\": {\n                        \"type\": \"object\",\n                        \"properties\": {\n                            \"action\": {\"type\": \"string\"},\n                            \"id\": {\"type\": \"integer\"},\n                            \"verbose\": {\"type\": \"boolean\"},\n                            \"temperature\": {\"type\": \"number\"},\n                        },\n                        \"required\": [\"action\"],\n                    },\n                },\n            }\n        ]\n\n        results = list(\n            parse_tool_calls(\n                _make_responses([\"<tool_call>\", \"process\", \"</tool_call>\"]),\n                make_mlx_parser(\n                    \"<tool_call>\", \"</tool_call>\", _parser_with_string_args\n                ),\n                tools,\n            )\n        )\n\n        assert len(results) == 1\n        assert isinstance(results[0], ToolCallResponse)\n\n        args = json.loads(results[0].tool_calls[0].arguments)  # pyright: ignore[reportAny]\n        assert args == {\n            \"action\": \"output\",\n            \"id\": 0,\n            \"verbose\": True,\n            \"temperature\": 0.75,\n        }\n\n    def test_schema_coercion_skips_unknown_tools(self):\n        \"\"\"If no matching tool schema exists, arguments should remain unchanged.\"\"\"\n\n        def _parser_with_string_id(_text: str) -> dict[str, Any]:\n            return {\n                \"name\": \"process\",\n                \"arguments\": {\"action\": \"output\", \"id\": \"0\"},\n            }\n\n        tools = [\n            {\n                \"type\": \"function\",\n                \"function\": {\n                    \"name\": \"different_tool\",\n                    \"parameters\": {\n                        \"type\": \"object\",\n                        \"properties\": {\"id\": {\"type\": \"integer\"}},\n                    },\n                },\n            }\n        ]\n\n        results = list(\n            parse_tool_calls(\n                _make_responses([\"<tool_call>\", \"process\", \"</tool_call>\"]),\n                make_mlx_parser(\"<tool_call>\", \"</tool_call>\", _parser_with_string_id),\n                tools,\n            )\n        )\n\n        assert len(results) == 1\n        assert isinstance(results[0], ToolCallResponse)\n\n        args = json.loads(results[0].tool_calls[0].arguments)  # pyright: ignore[reportAny]\n        assert args == {\"action\": \"output\", \"id\": \"0\"}\n"
  },
  {
    "path": "src/exo/worker/tests/unittests/test_runner/test_runner_supervisor.py",
    "content": "import multiprocessing as mp\nfrom typing import cast\n\nimport anyio\nimport pytest\n\nfrom exo.shared.models.model_cards import ModelId\nfrom exo.shared.types.chunks import ErrorChunk\nfrom exo.shared.types.common import CommandId, NodeId\nfrom exo.shared.types.events import ChunkGenerated, Event, RunnerStatusUpdated\nfrom exo.shared.types.tasks import Task, TaskId, TextGeneration\nfrom exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams\nfrom exo.shared.types.worker.instances import BoundInstance, InstanceId\nfrom exo.shared.types.worker.runners import RunnerFailed, RunnerId\nfrom exo.utils.channels import channel, mp_channel\nfrom exo.worker.runner.runner_supervisor import RunnerSupervisor\nfrom exo.worker.tests.unittests.conftest import get_bound_mlx_ring_instance\n\n\nclass _DeadProcess:\n    exitcode = -6\n\n    def start(self) -> None:\n        return None\n\n    def is_alive(self) -> bool:\n        return False\n\n    def join(self, _timeout: float | None = None) -> None:\n        return None\n\n    def terminate(self) -> None:\n        return None\n\n    def kill(self) -> None:\n        return None\n\n\n@pytest.mark.asyncio\nasync def test_check_runner_emits_error_chunk_for_inflight_text_generation() -> None:\n    event_sender, event_receiver = channel[Event]()\n    task_sender, _ = mp_channel[Task]()\n    cancel_sender, _ = mp_channel[TaskId]()\n    _, ev_recv = mp_channel[Event]()\n\n    bound_instance: BoundInstance = get_bound_mlx_ring_instance(\n        instance_id=InstanceId(\"instance-a\"),\n        model_id=ModelId(\"mlx-community/Llama-3.2-1B-Instruct-4bit\"),\n        runner_id=RunnerId(\"runner-a\"),\n        node_id=NodeId(\"node-a\"),\n    )\n\n    supervisor = RunnerSupervisor(\n        shard_metadata=bound_instance.bound_shard,\n        bound_instance=bound_instance,\n        runner_process=cast(\"mp.Process\", cast(object, _DeadProcess())),\n        initialize_timeout=400,\n        _ev_recv=ev_recv,\n        _task_sender=task_sender,\n        _event_sender=event_sender,\n        _cancel_sender=cancel_sender,\n    )\n\n    command_id = CommandId(\"cmd-a\")\n    task = TextGeneration(\n        task_id=TaskId(\"task-a\"),\n        instance_id=bound_instance.instance.instance_id,\n        command_id=command_id,\n        task_params=TextGenerationTaskParams(\n            model=bound_instance.bound_shard.model_card.model_id,\n            input=[InputMessage(role=\"user\", content=\"hi\")],\n            stream=True,\n        ),\n    )\n    supervisor.in_progress[task.task_id] = task\n    supervisor.shutdown = lambda: None\n\n    await supervisor._check_runner(RuntimeError(\"boom\"))  # pyright: ignore[reportPrivateUsage]\n\n    got_chunk = await event_receiver.receive()\n    got_status = await event_receiver.receive()\n\n    assert isinstance(got_chunk, ChunkGenerated)\n    assert got_chunk.command_id == command_id\n    assert isinstance(got_chunk.chunk, ErrorChunk)\n    assert \"Runner shutdown before completing command\" in got_chunk.chunk.error_message\n\n    assert isinstance(got_status, RunnerStatusUpdated)\n    assert isinstance(got_status.runner_status, RunnerFailed)\n\n    event_sender.close()\n    with anyio.move_on_after(0.1):\n        await event_receiver.aclose()\n"
  },
  {
    "path": "tests/auto_bench.sh",
    "content": "#!/usr/bin/env bash\n\n[ $# -lt 1 ] && {\n  echo \"Usage: $0 host1 [host2 ...]\"\n  exit 1\n}\n\n[ -z \"$(git status --porcelain)\" ] || {\n  echo \"Uncommitted changes\"\n  exit 1\n}\n\ncommit=$(git rev-parse HEAD)\ngit fetch -q origin\ngit branch -r --contains \"$commit\" | grep -qE '^\\s*origin/' || {\n  echo \"Not pushed to origin\"\n  exit 1\n}\nhosts=(\"$@\")\n\nfor host; do\n  ssh -T -o BatchMode=yes -o ServerAliveInterval=30 \"$host@$host\" \\\n    \"EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix build github:exo-explore/exo/$commit\" &\ndone\nwait\n\ncleanup() {\n  for host in \"${hosts[@]}\"; do\n    ssh -T -o BatchMode=yes \"$host@$host\" \"pkill -f bin/exo\" &\n  done\n  sleep 1\n  jobs -pr | xargs -r kill 2>/dev/null || true\n}\ntrap 'cleanup' EXIT INT TERM\n\nfor host; do\n  ssh -T -o BatchMode=yes -o ServerAliveInterval=30 \"$host@$host\" \\\n    \"EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit\" &>/dev/null &\ndone\n\nfor host; do\n  echo \"Waiting for $host...\" 1>&2\n  until curl -sf \"http://$host:52415/models\" &>/dev/null; do sleep 1; done\ndone\n\necho \"Waiting 30s for cluster setup\" 1>&2\nsleep 30\necho \"EXO loaded\" 1>&2\nbench_runner=\"${hosts[0]}\"\nmkdir -p \"./bench/$commit\"\nnix run .#exo-get-all-models-on-cluster -- \"$bench_runner\" | while IFS= read -r model; do\n  echo \"running bench for $model\" 1>&2\n  ssh -Tn -o BatchMode=yes -o ServerAliveInterval=30 \"$bench_runner@$bench_runner\" \"/nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit#exo-bench -- --model $model --pp 128 4096 --tg 128 --stdout --skip-tensor-ring\" >>\"./bench/$commit/${model//\\//--}.json\"\n  echo\ndone\n"
  },
  {
    "path": "tests/eval_tool_calls.sh",
    "content": "#!/usr/bin/env bash\n\n[ $# -lt 1 ] && {\n  echo \"Usage: $0 host1 [host2 ...]\"\n  exit 1\n}\n\n[ -z \"$(git status --porcelain)\" ] || {\n  echo \"Uncommitted changes\"\n  exit 1\n}\n\ncommit=$(git rev-parse HEAD)\ngit fetch -q origin\ngit branch -r --contains \"$commit\" | grep -qE '^\\s*origin/' || {\n  echo \"Not pushed to origin\"\n  exit 1\n}\nhosts=(\"$@\")\ncleanup() {\n  for host in \"${hosts[@]}\"; do\n    ssh -T -o BatchMode=yes \"$host@$host\" \"pkill -f bin/exo\" &\n  done\n  sleep 1\n  jobs -pr | xargs -r kill 2>/dev/null || true\n}\ntrap 'cleanup' EXIT INT TERM\n\nfor host; do\n  ssh -T -o BatchMode=yes -o ServerAliveInterval=30 \"$host@$host\" \\\n    \"EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix build github:exo-explore/exo/$commit\" &\ndone\nwait\nfor host; do\n  ssh -T -o BatchMode=yes -o ServerAliveInterval=30 \"$host@$host\" \\\n    \"EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit\" &>/dev/null &\ndone\n\nfor host; do\n  echo \"Waiting for $host...\" 1>&2\n  until curl -sf \"http://$host:52415/models\" &>/dev/null; do sleep 1; done\ndone\n\necho \"Waiting 30s for cluster setup\" 1>&2\nsleep 30\necho \"EXO loaded\" 1>&2\neval_runner=\"${hosts[0]}\"\nmkdir -p \"./bench/$commit\"\nnix run .#exo-get-all-models-on-cluster -- \"$eval_runner\" | while IFS= read -r model; do\n  echo \"running eval for $model\" 1>&2\n  ssh -Tn -o BatchMode=yes -o ServerAliveInterval=30 \"$eval_runner@$eval_runner\" \\\n    \"/nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit#exo-eval-tool-calls -- --model $model --stdout\" \\\n    >>\"./bench/$commit/${model//\\//--}-eval.json\"\n  echo\ndone\n"
  },
  {
    "path": "tests/get_all_models_on_cluster.py",
    "content": "#!/usr/bin/env python3\n# pyright: reportAny=false\nimport json\nimport subprocess\nimport sys\nfrom typing import Any, cast\nfrom urllib.request import urlopen\n\nh = sys.argv[1] if len(sys.argv) > 1 else sys.exit(f\"USAGE: {sys.argv[0]} host\")\nts = subprocess.run(\n    [\"tailscale\", \"status\"], check=True, text=True, capture_output=True\n).stdout.splitlines()\nip = next(\n    (sl[0] for line in ts if len(sl := line.split()) >= 2 if sl[1] == h), None\n) or sys.exit(f\"{h} not found in tailscale\")\nwith urlopen(f\"http://{ip}:52415/state\", timeout=5) as r:\n    data = json.loads(r.read()).get(\"downloads\", {})\n\n\ndef mid(x: dict[str, Any]) -> str | None:\n    for k in (\n        \"DownloadCompleted\",\n        \"shardMetadata\",\n        \"PipelineShardMetadata\",\n        \"modelCard\",\n        \"modelId\",\n    ):\n        x = x.get(k, {})\n    return cast(str | None, x if x != {} else None)\n\n\ncommon = set[str].intersection(\n    *[{m for d in nid if (m := mid(d))} for nid in data.values()]\n)\nfor c in common:\n    print(c)\n"
  },
  {
    "path": "tests/headless_runner.py",
    "content": "import socket\nfrom typing import Literal\n\nimport anyio\nfrom fastapi import FastAPI\nfrom fastapi.responses import Response, StreamingResponse\nfrom hypercorn import Config\nfrom hypercorn.asyncio import serve  # pyright: ignore[reportUnknownVariableType]\nfrom loguru import logger\nfrom pydantic import BaseModel\n\nfrom exo.shared.constants import EXO_MODELS_DIR\nfrom exo.shared.models.model_cards import ModelCard, ModelId\nfrom exo.shared.types.chunks import TokenChunk\nfrom exo.shared.types.commands import CommandId\nfrom exo.shared.types.common import Host, NodeId\nfrom exo.shared.types.events import ChunkGenerated, Event, RunnerStatusUpdated\nfrom exo.shared.types.tasks import (\n    ConnectToGroup,\n    LoadModel,\n    Shutdown,\n    StartWarmup,\n    Task,\n    TextGeneration,\n)\nfrom exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams\nfrom exo.shared.types.worker.instances import (\n    BoundInstance,\n    Instance,\n    InstanceId,\n    MlxJacclInstance,\n    MlxRingInstance,\n)\nfrom exo.shared.types.worker.runners import (\n    RunnerFailed,\n    RunnerId,\n    RunnerShutdown,\n    ShardAssignments,\n)\nfrom exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata\nfrom exo.utils.channels import channel, mp_channel\nfrom exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer\nfrom exo.worker.runner.bootstrap import entrypoint\n\n\nclass Tests(BaseModel):\n    # list[hostname, ip addr]\n    devs: list[list[str]]\n    ibv_devs: list[list[str | None]] | None\n    model_id: ModelId\n    kind: Literal[\"ring\", \"jaccl\", \"both\"]\n\n\niid = InstanceId(\"im testing here\")\n\n\nasync def main():\n    logger.info(\"starting cool server majig\")\n    cfg = Config()\n    cfg.bind = \"0.0.0.0:52414\"\n    # nb: shared.logging needs updating if any of this changes\n    cfg.accesslog = \"-\"\n    cfg.errorlog = \"-\"\n    ev = anyio.Event()\n    app = FastAPI()\n    app.post(\"/run_test\")(run_test)\n    app.post(\"/kill\")(lambda: kill(ev))\n    app.get(\"/tb_detection\")(tb_detection)\n    app.get(\"/models\")(list_models)\n    await serve(\n        app,  # type: ignore\n        cfg,\n        shutdown_trigger=lambda: ev.wait(),\n    )\n\n\ndef kill(ev: anyio.Event):\n    ev.set()\n    return Response(status_code=204)\n\n\nasync def tb_detection():\n    send, recv = channel[GatheredInfo]()\n    ig = InfoGatherer(send)\n    with anyio.move_on_after(1):\n        await ig._monitor_system_profiler_thunderbolt_data()  # pyright: ignore[reportPrivateUsage]\n    with recv:\n        return recv.collect()\n\n\ndef list_models():\n    sent = set[str]()\n    for path in EXO_MODELS_DIR.rglob(\"model-*.safetensors\"):\n        if \"--\" not in path.parent.name:\n            continue\n        name = path.parent.name.replace(\"--\", \"/\")\n        if name in sent:\n            continue\n        sent.add(name)\n        yield ModelId(path.parent.name.replace(\"--\", \"/\"))\n\n\nasync def run_test(test: Tests):\n    weird_hn = socket.gethostname()\n    for dev in test.devs:\n        if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):\n            hn = dev[0]\n            break\n    else:\n        raise ValueError(f\"{weird_hn} not in {test.devs}\")\n\n    async def run():\n        logger.info(f\"testing {test.model_id}\")\n\n        instances: list[Instance] = []\n        if test.kind in [\"ring\", \"both\"]:\n            i = await ring_instance(test, hn)\n            if i is None:\n                yield 
\"no model found\"\n                return\n            instances.append(i)\n        if test.kind in [\"jaccl\", \"both\"]:\n            i = await jaccl_instance(test)\n            if i is None:\n                yield \"no model found\"\n                return\n            instances.append(i)\n\n        for instance in instances:\n            recv = await execute_test(test, instance, hn)\n\n            str_out = \"\"\n\n            for item in recv:\n                if isinstance(item, ChunkGenerated):\n                    assert isinstance(item.chunk, TokenChunk)\n                    str_out += item.chunk.text\n\n                if isinstance(item, RunnerStatusUpdated) and isinstance(\n                    item.runner_status, (RunnerFailed, RunnerShutdown)\n                ):\n                    yield str_out + \"\\n\"\n                    yield item.model_dump_json() + \"\\n\"\n\n    return StreamingResponse(run())\n\n\nasync def ring_instance(test: Tests, hn: str) -> Instance | None:\n    hbn = [Host(ip=\"198.51.100.0\", port=52417) for _ in test.devs]\n    world_size = len(test.devs)\n    for i in range(world_size):\n        if test.devs[i][0] == hn:\n            hn = test.devs[i][0]\n        hbn[(i - 1) % world_size] = Host(ip=test.devs[i - 1][1], port=52417)\n        hbn[(i + 1) % world_size] = Host(ip=test.devs[i + 1][1], port=52417)\n        hbn[i] = Host(ip=\"0.0.0.0\", port=52417)\n        break\n    else:\n        raise ValueError(f\"{hn} not in {test.devs}\")\n\n    card = await ModelCard.load(test.model_id)\n    instance = MlxRingInstance(\n        instance_id=iid,\n        ephemeral_port=52417,\n        hosts_by_node={NodeId(hn): hbn},\n        shard_assignments=ShardAssignments(\n            model_id=test.model_id,\n            node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},\n            runner_to_shard={\n                RunnerId(test.devs[i][0]): PipelineShardMetadata(\n                    model_card=card,\n                    device_rank=i,\n                    world_size=world_size,\n                    start_layer=(card.n_layers // world_size) * i,\n                    end_layer=min(\n                        card.n_layers, (card.n_layers // world_size) * (i + 1)\n                    ),\n                    n_layers=min(card.n_layers, (card.n_layers // world_size) * (i + 1))\n                    - (card.n_layers // world_size) * i,\n                )\n                for i in range(world_size)\n            },\n        ),\n    )\n\n    return instance\n\n\nasync def execute_test(test: Tests, instance: Instance, hn: str) -> list[Event]:\n    world_size = len(test.devs)\n    commands: list[Task] = [\n        (LoadModel(instance_id=iid)),\n        (StartWarmup(instance_id=iid)),\n        (\n            TextGeneration(\n                task_params=TextGenerationTaskParams(\n                    model=test.model_id,\n                    instructions=\"You are a helpful assistant\",\n                    input=[\n                        InputMessage(\n                            role=\"user\", content=\"What is the capital of France?\"\n                        )\n                    ],\n                ),\n                command_id=CommandId(\"yo\"),\n                instance_id=iid,\n            )\n        ),\n        (Shutdown(runner_id=RunnerId(hn), instance_id=iid)),\n    ]\n    if world_size > 1:\n        commands.insert(0, ConnectToGroup(instance_id=iid))\n    bound_instance = BoundInstance(\n        instance=instance, 
bound_runner_id=RunnerId(hn), bound_node_id=NodeId(hn)\n    )\n    ev_send, _ev_recv = mp_channel[Event]()\n    task_send, task_recv = mp_channel[Task]()\n\n    for command in commands:\n        task_send.send(command)\n\n    entrypoint(\n        bound_instance,\n        ev_send,\n        task_recv,\n        logger,\n    )\n\n    # TODO(evan): return ev_recv.collect()\n    return []\n\n\nasync def jaccl_instance(test: Tests) -> MlxJacclInstance | None:\n    card = await ModelCard.load(test.model_id)\n    world_size = len(test.devs)\n    assert test.ibv_devs\n\n    return MlxJacclInstance(\n        instance_id=iid,\n        jaccl_devices=test.ibv_devs,\n        # rank 0 is always coordinator\n        jaccl_coordinators={\n            NodeId(host[0]): test.devs[0][1] + \":52417\" for host in test.devs\n        },\n        shard_assignments=ShardAssignments(\n            model_id=test.model_id,\n            node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},\n            runner_to_shard={\n                RunnerId(host[0]): TensorShardMetadata(\n                    model_card=card,\n                    device_rank=i,\n                    world_size=world_size,\n                    start_layer=0,\n                    end_layer=card.n_layers,\n                    n_layers=card.n_layers,\n                )\n                for i, host in enumerate(test.devs)\n            },\n        ),\n    )\n\n\nif __name__ == \"__main__\":\n    anyio.run(main)\n"
  },
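  {
    "path": "tests/run_test_example.py",
    "content": "\"\"\"Minimal usage sketch for the test server above.\n\nHypothetical example: this file name, hostname, and IP are illustrative\nonly. The endpoint, port, and payload shape come from the server's Tests\nmodel (devs, ibv_devs, model_id, kind). For multi-device tests the same\nbody must be POSTed to every host (see tests/start_distributed_test.py).\n\"\"\"\n\nimport json\nfrom urllib.request import Request, urlopen\n\nbody = json.dumps(\n    {\n        # [hostname, ip addr] per device; placeholder values\n        \"devs\": [[\"mac1\", \"100.64.0.1\"]],\n        \"ibv_devs\": None,  # only needed for kind=\"jaccl\" or \"both\"\n        \"model_id\": \"mlx-community/Llama-3.2-1B-Instruct-4bit\",\n        \"kind\": \"ring\",\n    }\n).encode()\n\nwith urlopen(\n    Request(\n        \"http://mac1:52414/run_test\",\n        data=body,\n        method=\"POST\",\n        headers={\"Content-Type\": \"application/json\"},\n    ),\n    timeout=300,\n) as response:\n    print(response.read().decode(errors=\"replace\"))\n"
  },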
  {
    "path": "tests/run_exo_on.sh",
    "content": "#!/usr/bin/env bash\nset -euo pipefail\n\n[ $# -lt 1 ] && {\n  echo \"Usage: $0 host1 [host2 ...]\"\n  exit 1\n}\n\n[ -z \"$(git status --porcelain)\" ] || {\n  echo \"Uncommitted changes\"\n  exit 1\n}\n\ncommit=$(git rev-parse HEAD)\ngit fetch -q origin\ngit branch -r --contains \"$commit\" | grep -qE '^\\s*origin/' || {\n  echo \"Not pushed to origin\"\n  exit 1\n}\n\necho \"Deploying $commit to $# hosts...\"\nhosts=(\"$@\")\ncleanup() {\n  for host in \"${hosts[@]}\"; do\n    ssh -T -o BatchMode=yes \"$host@$host\" \"pkill -f bin/exo\" &\n  done\n  wait\n  jobs -pr | xargs -r kill 2>/dev/null || true\n}\ntrap 'cleanup' EXIT INT TERM\n\ncolours=($'\\e[31m' $'\\e[32m' $'\\e[33m' $'\\e[34m')\nreset=$'\\e[0m'\ni=0\nfor host; do\n  colour=${colours[i++ % 4]}\n  ssh -T -o BatchMode=yes -o ServerAliveInterval=30 \"$host@$host\" \\\n    \"EXO_LIBP2P_NAMESPACE=$commit /nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit\" |&\n    awk -v p=\"${colour}[${host}]${reset}\" '{ print p $0; fflush() }' &\ndone\n\nfor host; do\n  echo \"Waiting for $host...\"\n  until curl -sf \"http://$host:52415/models\" &>/dev/null; do sleep 1; done\ndone\nwait\n"
  },
  {
    "path": "tests/start_distributed_test.py",
    "content": "#!/usr/bin/env python3\nimport itertools\nimport json\nimport subprocess\nimport sys\nfrom concurrent.futures import ThreadPoolExecutor\nfrom typing import Any, cast\nfrom urllib.request import Request, urlopen\n\nif not (args := sys.argv[1:]):\n    sys.exit(\n        f\"USAGE: {sys.argv[0]} <kind> [host1] [host2] ...\\nkind is optional, and should be jaccl or ring\"\n    )\n\nkind = args[0] if args[0] in (\"jaccl\", \"ring\") else \"both\"\nhosts = args[1:] if kind != \"both\" else args\nts = subprocess.run(\n    [\"tailscale\", \"status\"], check=True, text=True, capture_output=True\n).stdout.splitlines()\nip = {sl[1]: sl[0] for line in ts if len(sl := line.split()) >= 2}\nips = [ip[h] for h in hosts]\ndevs = [[h, ip[h]] for h in hosts]\nn = len(hosts)\n\n\ndef get_tb(a: str) -> list[dict[str, Any]]:\n    with urlopen(f\"http://{a}:52414/tb_detection\", timeout=5) as r:  # pyright: ignore[reportAny]\n        return json.loads(r.read())  # pyright: ignore[reportAny]\n\n\ndef get_models(a: str) -> set[str]:\n    with urlopen(f\"http://{a}:52414/models\", timeout=5) as r:  # pyright: ignore[reportAny]\n        return set(json.loads(r.read()))  # pyright: ignore[reportAny]\n\n\ndef run(h: str, a: str, body: bytes) -> None:\n    with urlopen(\n        Request(\n            f\"http://{a}:52414/run_test\",\n            data=body,\n            method=\"POST\",\n            headers={\"Content-Type\": \"application/json\"},\n        ),\n        timeout=300,\n    ) as r:  # pyright: ignore[reportAny]\n        for line in r.read().decode(errors=\"replace\").splitlines():  # pyright: ignore[reportAny]\n            print(f\"\\n{h}@{a}: {line}\", flush=True)\n\n\nwith ThreadPoolExecutor(n) as exctr:\n    if kind in (\"jaccl\", \"both\"):\n        payloads = list(exctr.map(get_tb, ips))\n\n        u2e = {\n            ident[\"domainUuid\"]: (i, ident[\"rdmaInterface\"])\n            for i, p in enumerate(payloads)\n            for d in p\n            for ident in cast(\n                list[dict[str, str]],\n                d.get(\"MacThunderboltIdentifiers\", {}).get(\"idents\", []),  # pyright: ignore[reportAny]\n            )\n        }\n        edges = {\n            (u2e[s][0], u2e[t][0]): u2e[t][1]\n            for p in payloads\n            for d in p\n            for c in d.get(\"MacThunderboltConnections\", {}).get(\"conns\", [])  # pyright: ignore[reportAny]\n            if (s := c[\"sourceUuid\"]) in u2e and (t := c[\"sinkUuid\"]) in u2e  # pyright: ignore[reportAny]\n        }\n        ibv_devs = [[edges.get((i, j)) for j in range(n)] for i in range(n)]\n    else:\n        ibv_devs = None\n\n    models = set[str].intersection(*exctr.map(get_models, ips))\n\n    print(\"\\n\")\n    print(\"=\" * 70)\n    print(f\"Starting test with {models}\")\n    print(\"=\" * 70)\n    print(\"\\n\")\n    for model in models:\n        body = json.dumps(\n            {\"devs\": devs, \"model_id\": model, \"ibv_devs\": ibv_devs, \"kind\": kind}\n        ).encode()\n        list(exctr.map(run, hosts, ips, itertools.repeat(body)))\n"
  },
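  {
    "path": "tmp/tb_matrix_example.py",
    "content": "\"\"\"Toy walkthrough of the ibv_devs matrix built in tests/start_distributed_test.py.\n\nHypothetical example: this file name, the UUIDs, and the interface names\nare made up. The derivation mirrors the real script: map each Thunderbolt\ndomainUuid to (host index, rdma interface), then turn reported connections\ninto a host-by-host matrix of interface names.\n\"\"\"\n\n# payloads[i] stands in for host i's /tb_detection output, trimmed to the fields used\npayloads = [\n    [\n        {\n            \"MacThunderboltIdentifiers\": {\n                \"idents\": [{\"domainUuid\": \"uuid-a\", \"rdmaInterface\": \"rdma0\"}]\n            },\n            \"MacThunderboltConnections\": {\n                \"conns\": [{\"sourceUuid\": \"uuid-a\", \"sinkUuid\": \"uuid-b\"}]\n            },\n        }\n    ],\n    [\n        {\n            \"MacThunderboltIdentifiers\": {\n                \"idents\": [{\"domainUuid\": \"uuid-b\", \"rdmaInterface\": \"rdma0\"}]\n            },\n            \"MacThunderboltConnections\": {\n                \"conns\": [{\"sourceUuid\": \"uuid-b\", \"sinkUuid\": \"uuid-a\"}]\n            },\n        }\n    ],\n]\n\nn = len(payloads)\nu2e = {\n    ident[\"domainUuid\"]: (i, ident[\"rdmaInterface\"])\n    for i, p in enumerate(payloads)\n    for d in p\n    for ident in d.get(\"MacThunderboltIdentifiers\", {}).get(\"idents\", [])\n}\nedges = {\n    (u2e[s][0], u2e[t][0]): u2e[t][1]\n    for p in payloads\n    for d in p\n    for c in d.get(\"MacThunderboltConnections\", {}).get(\"conns\", [])\n    if (s := c[\"sourceUuid\"]) in u2e and (t := c[\"sinkUuid\"]) in u2e\n}\nibv_devs = [[edges.get((i, j)) for j in range(n)] for i in range(n)]\nprint(ibv_devs)  # [[None, 'rdma0'], ['rdma0', None]]\n"
  },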
  {
    "path": "tmp/config_examples/claude_code.sh",
    "content": "#!/bin/bash\n# Run Claude Code against a local exo cluster! (Here, GPT OSS 120B)\nANTHROPIC_BASE_URL=\"http://localhost:52415/\" \\\n  ANTHROPIC_AUTH_TOKEN=\"dummy\" \\\n  ANTHROPIC_MODEL=\"mlx-community/gpt-oss-120b-MXFP4-Q8\" \\\n  ANTHROPIC_SMALL_FAST_MODEL=\"mlx-community/gpt-oss-120b-MXFP4-Q8\" \\\n  CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC=1 \\\n  claude\n"
  },
  {
    "path": "tmp/config_examples/opencode.json",
    "content": "{\n  \"$schema\": \"https://opencode.ai/config.json\",\n  \"model\": \"exo/mlx-community/gpt-oss-120b-MXFP4-Q8\",\n  \"provider\": {\n    \"exo\": {\n      \"api\": \"http://localhost:52415/v1\",\n      \"models\": {\n        \"mlx-community/gpt-oss-120b-MXFP4-Q8\": {\n          \"name\": \"GPT OSS 120B\",\n          \"limit\": {\n            \"context\": 32768,\n            \"output\": 8192\n          }\n        }\n      }\n    }\n  }\n}\n"
  },
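  {
    "path": "tmp/config_examples/check_endpoint.py",
    "content": "\"\"\"Reachability check for the local exo endpoint the configs above point at.\n\nHypothetical helper: this file name is illustrative. It fetches\nhttp://localhost:52415/models (the same URL tests/run_exo_on.sh polls)\nbefore pointing Claude Code or opencode at the cluster.\n\"\"\"\n\nimport json\nfrom urllib.request import urlopen\n\nwith urlopen(\"http://localhost:52415/models\", timeout=5) as response:\n    print(json.dumps(json.loads(response.read()), indent=2))\n"
  },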
  {
    "path": "tmp/gen_card.py",
    "content": "\"\"\"\nGenerates inference model cards for EXO.\nUsage:\n    uv run tmp/gen_card.py mlx-community/my_cool_model-8bit [repo-id/model-id-2] [...]\n\nModel Cards require cleanup for family & quantization data\n\"\"\"\n\nimport sys\n\nimport anyio\n\nfrom exo.shared.models.model_cards import ModelCard, ModelId\n\n\nasync def main():\n    if len(sys.argv) == 1:\n        print(f\"USAGE: {sys.argv[0]} repo-id/model-id-1 [repo-id/model-id-2] [...]\")\n        quit(1)\n    print(\"Remember! Model Cards require cleanup for family & quantization data\")\n    for arg in sys.argv[1:]:\n        mid = ModelId(arg)\n        mc = await ModelCard.fetch_from_hf(mid)\n        await mc.save(\n            anyio.Path(__file__).parent.parent\n            / \"resources\"\n            / \"inference_model_cards\"\n            / (mid.normalize() + \".toml\")\n        )\n\n\nif __name__ == \"__main__\":\n    anyio.run(main)\n"
  },
  {
    "path": "tmp/prompt.txt",
    "content": "Summarise this Wikipedia article for me:\n\nTransition from Republic to Empire\n\nAugustus of Prima Porta\nRome had begun expanding shortly after the founding of the Roman Republic in the 6th century BC, though not outside the Italian Peninsula until the 3rd century BC. The Republic was not a nation-state in the modern sense, but a network of self-ruled towns (with varying degrees of independence from the Senate) and provinces administered by military commanders. It was governed by annually elected magistrates (Roman consuls above all) in conjunction with the Senate.[22] The 1st century BC was a time of political and military upheaval, which ultimately led to rule by emperors.[23][24][25] The consuls' military power rested in the Roman legal concept of imperium, meaning \"command\" (typically in a military sense).[26] Occasionally, successful consuls or generals were given the honorary title imperator (commander); this is the origin of the word emperor, since this title was always bestowed to the early emperors.[27][g]\n\nRome suffered a long series of internal conflicts, conspiracies, and civil wars from the late second century BC (see Crisis of the Roman Republic) while greatly extending its power beyond Italy. In 44 BC Julius Caesar was briefly perpetual dictator before being assassinated by a faction that opposed his concentration of power. This faction was driven from Rome and defeated at the Battle of Philippi in 42 BC by Mark Antony and Caesar's adopted son Octavian. Antony and Octavian divided the Roman world between them, but this did not last long. Octavian's forces defeated those of Mark Antony and Cleopatra at the Battle of Actium in 31 BC. In 27 BC the Senate gave him the title Augustus (\"venerated\") and made him princeps (\"foremost\") with proconsular imperium, thus beginning the Principate, the first epoch of Roman imperial history. Although the republic stood in name, Augustus had all meaningful authority.[29] During his 40-year rule, a new constitutional order emerged so that, upon his death, Tiberius would succeed him as the new de facto monarch.[30]\n\nPax Romana\nMain article: Pax Romana\nThe so-called \"Five Good Emperors\" of 96–180 AD\n\nNerva (r. 96–98)\n\nTrajan (r. 98–117)\n\nHadrian (r. 117–138)\n\nAntoninus Pius (r. 138–161)\n\nMarcus Aurelius (r. 161–180)\nThe 200 years that began with Augustus's rule are traditionally regarded as the Pax Romana (\"Roman Peace\"). The cohesion of the empire was furthered by a degree of social stability and economic prosperity that Rome had never before experienced. Uprisings in the provinces were infrequent and put down \"mercilessly and swiftly\".[31] The success of Augustus in establishing principles of dynastic succession was limited by his outliving a number of talented potential heirs. The Julio-Claudian dynasty lasted for four more emperors—Tiberius, Caligula, Claudius, and Nero—before it yielded in 69 AD to the strife-torn Year of the Four Emperors, from which Vespasian emerged as the victor. Vespasian became the founder of the brief Flavian dynasty, followed by the Nerva–Antonine dynasty which produced the \"Five Good Emperors\": Nerva, Trajan, Hadrian, Antoninus Pius, and Marcus Aurelius.[32]\n\nAmong the so-called “Five Good Emperors,” Hadrian (r. 
117–138) is particularly noted for consolidating the empire’s frontiers and embarking on ambitious building projects throughout the provinces.[33] In Judaea, which had long been the center of Jewish national and religious life, his reign marked a decisive turning point. After earlier Jewish resistance to Roman rule, Hadrian visited the region in 129/130 CE and refounded Jerusalem as the Roman colony Aelia Capitolina, naming it after his family (Aelius) and the Capitoline Triad.[34] The refoundation overlaid the destroyed Jewish city with a new Roman urban plan, and included the construction of a Temple to Jupiter on the site of the former Jewish Temple.[35] Later tradition and archaeological evidence also indicate a Temple of Venus near the site of the Holy Sepulchre.[36]\n\nHadrian’s measures, combined with restrictions on Jewish practices, helped spark the Bar Kokhba Revolt (132–135 CE). After crushing the uprising, Roman forces expelled most Jews from Jerusalem, barring their entry except on certain days, and rebuilt the city as a statement of imperial power and domination.[33] Most scholars consider Hadrianic Aelia to have been unwalled, with free-standing gate complexes (such as the northern gate beneath today’s Damascus Gate) rather than a continuous defensive circuit.[37]\n\nTransition from classical to late antiquity\nMain articles: Later Roman Empire and Fall of the Western Roman Empire\nSee also: Barbarian kingdoms and Byzantine Empire\n\nThe Barbarian invasions consisted of the movement of (mainly) ancient Germanic peoples into Roman territory. Historically, this event marked the transition between classical antiquity and the Middle Ages.\nIn the view of contemporary Greek historian Cassius Dio, the accession of Commodus in 180 marked the descent \"from a kingdom of gold to one of rust and iron\",[38] a comment which has led some historians, notably Edward Gibbon, to take Commodus' reign as the beginning of the Empire's decline.[39][40]\n\nIn 212, during the reign of Caracalla, Roman citizenship was granted to all freeborn inhabitants of the empire. The Severan dynasty was tumultuous; an emperor's reign was ended routinely by his murder or execution and, following its collapse, the Empire was engulfed by the Crisis of the Third Century, a period of invasions, civil strife, economic disorder, and plague.[41] In defining historical epochs, this crisis sometimes marks the transition from Classical to Late Antiquity. Aurelian (r. 270–275) stabilised the empire militarily and Diocletian reorganised and restored much of it in 285.[42] Diocletian's reign brought the empire's most concerted effort against the perceived threat of Christianity, the \"Great Persecution\".[43]\n\nDiocletian divided the empire into four regions, each ruled by a separate tetrarch.[44] Confident that he fixed the disorder plaguing Rome, he abdicated along with his co-emperor, but the Tetrarchy collapsed shortly after. Order was eventually restored by Constantine the Great, who became the first emperor to convert to Christianity, and who established Constantinople as the new capital of the Eastern Empire. During the decades of the Constantinian and Valentinian dynasties, the empire was divided along an east–west axis, with dual power centres in Constantinople and Rome. Julian, who under the influence of his adviser Mardonius attempted to restore Classical Roman and Hellenistic religion, only briefly interrupted the succession of Christian emperors. 
Theodosius I, the last emperor to rule over both East and West, died in 395 after making Christianity the state religion.[45]\n\n\nThe Roman Empire by 476, noting western and eastern divisions\n\nThe administrative divisions of the Roman Empire in 395 AD\nFall in the West and survival in the East\nThe Western Roman Empire began to disintegrate in the early 5th century. The Romans fought off all invaders, most famously Attila,[46] but the empire had assimilated so many Germanic peoples of dubious loyalty to Rome that the empire started to dismember itself.[47] Most chronologies place the end of the Western Roman Empire in 476, when Romulus Augustulus was forced to abdicate to the Germanic warlord Odoacer.[48][49][50]\n\nOdoacer ended the Western Empire by declaring Zeno sole emperor and placing himself as Zeno's nominal subordinate. In reality, Italy was ruled by Odoacer alone.[48][49][51] The Eastern Roman Empire, called the Byzantine Empire by later historians, continued until the reign of Constantine XI Palaiologos, the last Roman emperor. He died in battle in 1453 against Mehmed II and his Ottoman forces during the siege of Constantinople. Mehmed II adopted the title of caesar in an attempt to claim a connection to the former Empire.[52][53] His claim was soon recognized by the Patriarchate of Constantinople, but not by European monarchs."
  },
  {
    "path": "tmp/quantize_and_upload.py",
    "content": "#!/usr/bin/env python3\n\"\"\"\nDownload an mflux model, quantize it, and upload to HuggingFace.\n\nUsage (run from mflux project directory):\n    cd /path/to/mflux\n    uv run python /path/to/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev\n    uv run python /path/to/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --skip-base --skip-8bit\n    uv run python /path/to/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --dry-run\n\nRequires:\n    - Must be run from mflux project directory using `uv run`\n    - huggingface_hub installed (add to mflux deps or install separately)\n    - HuggingFace authentication: run `huggingface-cli login` or set HF_TOKEN\n\"\"\"\n\nfrom __future__ import annotations\n\nimport argparse\nimport re\nimport shutil\nimport sys\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING\n\nif TYPE_CHECKING:\n    from mflux.models.flux.variants.txt2img.flux import Flux1\n\n\nHF_ORG = \"exolabs\"\n\n\ndef get_model_class(model_name: str) -> type:\n    \"\"\"Get the appropriate model class based on model name.\"\"\"\n    from mflux.models.fibo.variants.txt2img.fibo import FIBO\n    from mflux.models.flux.variants.txt2img.flux import Flux1\n    from mflux.models.flux2.variants.txt2img.flux2_klein import Flux2Klein\n    from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage\n    from mflux.models.z_image.variants.turbo.z_image_turbo import ZImageTurbo\n\n    model_name_lower = model_name.lower()\n    if \"qwen\" in model_name_lower:\n        return QwenImage\n    elif \"fibo\" in model_name_lower:\n        return FIBO\n    elif \"z-image\" in model_name_lower or \"zimage\" in model_name_lower:\n        return ZImageTurbo\n    elif \"flux2\" in model_name_lower or \"flux.2\" in model_name_lower:\n        return Flux2Klein\n    else:\n        return Flux1\n\n\ndef get_repo_name(model_name: str, bits: int | None) -> str:\n    \"\"\"Get the HuggingFace repo name for a model variant.\"\"\"\n    # Extract repo name from HF path (e.g., \"black-forest-labs/FLUX.1-Kontext-dev\" -> \"FLUX.1-Kontext-dev\")\n    base_name = model_name.split(\"/\")[-1] if \"/\" in model_name else model_name\n    suffix = f\"-{bits}bit\" if bits else \"\"\n    return f\"{HF_ORG}/{base_name}{suffix}\"\n\n\ndef get_local_path(output_dir: Path, model_name: str, bits: int | None) -> Path:\n    \"\"\"Get the local save path for a model variant.\"\"\"\n    # Extract repo name from HF path (e.g., \"black-forest-labs/FLUX.1-Kontext-dev\" -> \"FLUX.1-Kontext-dev\")\n    base_name = model_name.split(\"/\")[-1] if \"/\" in model_name else model_name\n    suffix = f\"-{bits}bit\" if bits else \"\"\n    return output_dir / f\"{base_name}{suffix}\"\n\n\ndef copy_source_repo(\n    source_repo: str,\n    local_path: Path,\n    dry_run: bool = False,\n) -> None:\n    \"\"\"Copy all files from source repo (replicating original HF structure).\"\"\"\n    print(f\"\\n{'=' * 60}\")\n    print(f\"Copying full repo from source: {source_repo}\")\n    print(f\"Output path: {local_path}\")\n    print(f\"{'=' * 60}\")\n\n    if dry_run:\n        print(\"[DRY RUN] Would download all files from source repo\")\n        return\n\n    from huggingface_hub import snapshot_download\n\n    # Download all files to our local path\n    snapshot_download(\n        repo_id=source_repo,\n        local_dir=local_path,\n    )\n\n    # Remove root-level safetensors files (flux.1-dev.safetensors, etc.)\n    # These are redundant with the component directories\n 
   for f in local_path.glob(\"*.safetensors\"):\n        print(f\"Removing root-level safetensors: {f.name}\")\n        if not dry_run:\n            f.unlink()\n\n    print(f\"Source repo copied to {local_path}\")\n\n\ndef load_and_save_quantized_model(\n    model_name: str,\n    bits: int,\n    output_path: Path,\n    dry_run: bool = False,\n) -> None:\n    \"\"\"Load a model with quantization and save it in mflux format.\"\"\"\n    print(f\"\\n{'=' * 60}\")\n    print(f\"Loading {model_name} with {bits}-bit quantization...\")\n    print(f\"Output path: {output_path}\")\n    print(f\"{'=' * 60}\")\n\n    if dry_run:\n        print(\"[DRY RUN] Would load and save quantized model\")\n        return\n\n    from mflux.models.common.config.model_config import ModelConfig\n\n    model_class = get_model_class(model_name)\n    model_config = ModelConfig.from_name(model_name=model_name, base_model=None)\n\n    model: Flux1 = model_class(\n        quantize=bits,\n        model_config=model_config,\n    )\n\n    print(f\"Saving model to {output_path}...\")\n    model.save_model(str(output_path))\n    print(f\"Model saved successfully to {output_path}\")\n\n\ndef copy_source_metadata(\n    source_repo: str,\n    local_path: Path,\n    dry_run: bool = False,\n) -> None:\n    \"\"\"Copy metadata files (LICENSE, README, etc.) from source repo, excluding safetensors.\"\"\"\n    print(f\"\\n{'=' * 60}\")\n    print(f\"Copying metadata from source repo: {source_repo}\")\n    print(f\"{'=' * 60}\")\n\n    if dry_run:\n        print(\"[DRY RUN] Would download metadata files (excluding *.safetensors)\")\n        return\n\n    from huggingface_hub import snapshot_download\n\n    # Download all files except safetensors to our local path\n    snapshot_download(\n        repo_id=source_repo,\n        local_dir=local_path,\n        ignore_patterns=[\"*.safetensors\"],\n    )\n    print(f\"Metadata files copied to {local_path}\")\n\n\ndef upload_to_huggingface(\n    local_path: Path,\n    repo_id: str,\n    dry_run: bool = False,\n    clean_remote: bool = False,\n) -> None:\n    \"\"\"Upload a saved model to HuggingFace.\"\"\"\n    print(f\"\\n{'=' * 60}\")\n    print(f\"Uploading to HuggingFace: {repo_id}\")\n    print(f\"Local path: {local_path}\")\n    print(f\"Clean remote first: {clean_remote}\")\n    print(f\"{'=' * 60}\")\n\n    if dry_run:\n        print(\"[DRY RUN] Would upload to HuggingFace\")\n        return\n\n    from huggingface_hub import HfApi\n\n    api = HfApi()\n\n    # Create the repo if it doesn't exist\n    print(f\"Creating/verifying repo: {repo_id}\")\n    api.create_repo(repo_id=repo_id, repo_type=\"model\", exist_ok=True)\n\n    # Clean remote repo if requested (delete old mflux-format files)\n    if clean_remote:\n        print(\"Cleaning old mflux-format files from remote...\")\n        try:\n            # Pattern for mflux numbered shards: <dir>/<number>.safetensors\n            numbered_pattern = re.compile(r\".*/\\d+\\.safetensors$\")\n\n            repo_files = api.list_repo_files(repo_id=repo_id, repo_type=\"model\")\n            for file_path in repo_files:\n                # Delete numbered safetensors (mflux format) and mflux index files\n                if numbered_pattern.match(file_path) or file_path.endswith(\n                    \"/model.safetensors.index.json\"\n                ):\n                    print(f\"  Deleting: {file_path}\")\n                    api.delete_file(\n                        path_in_repo=file_path, repo_id=repo_id, repo_type=\"model\"\n              
      )\n        except Exception as e:\n            print(f\"Warning: Could not clean remote files: {e}\")\n\n    # Upload the folder\n    print(\"Uploading folder contents...\")\n    api.upload_folder(\n        folder_path=str(local_path),\n        repo_id=repo_id,\n        repo_type=\"model\",\n    )\n    print(f\"Upload complete: https://huggingface.co/{repo_id}\")\n\n\ndef clean_local_files(local_path: Path, dry_run: bool = False) -> None:\n    \"\"\"Remove local model files after upload.\"\"\"\n    print(f\"\\nCleaning up: {local_path}\")\n    if dry_run:\n        print(\"[DRY RUN] Would remove local files\")\n        return\n\n    if local_path.exists():\n        shutil.rmtree(local_path)\n        print(f\"Removed {local_path}\")\n\n\ndef main() -> int:\n    parser = argparse.ArgumentParser(\n        description=\"Download an mflux model, quantize it, and upload to HuggingFace.\",\n        formatter_class=argparse.RawDescriptionHelpFormatter,\n        epilog=\"\"\"\nExamples:\n    # Process all variants (base, 4-bit, 8-bit) for FLUX.1-Kontext-dev\n    python tmp/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev\n\n    # Only process 4-bit variant\n    python tmp/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --skip-base --skip-8bit\n\n    # Save locally without uploading\n    python tmp/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --skip-upload\n\n    # Preview what would happen\n    python tmp/quantize_and_upload.py --model black-forest-labs/FLUX.1-Kontext-dev --dry-run\n        \"\"\",\n    )\n\n    parser.add_argument(\n        \"--model\",\n        \"-m\",\n        required=True,\n        help=\"HuggingFace model path (e.g., black-forest-labs/FLUX.1-Kontext-dev)\",\n    )\n    parser.add_argument(\n        \"--output-dir\",\n        type=Path,\n        default=Path(\"./tmp/models\"),\n        help=\"Local directory to save models (default: ./tmp/models)\",\n    )\n    parser.add_argument(\n        \"--skip-base\",\n        action=\"store_true\",\n        help=\"Skip base model (no quantization)\",\n    )\n    parser.add_argument(\n        \"--skip-4bit\",\n        action=\"store_true\",\n        help=\"Skip 4-bit quantized model\",\n    )\n    parser.add_argument(\n        \"--skip-8bit\",\n        action=\"store_true\",\n        help=\"Skip 8-bit quantized model\",\n    )\n    parser.add_argument(\n        \"--skip-download\",\n        action=\"store_true\",\n        help=\"Skip downloading/processing, only do upload/clean operations\",\n    )\n    parser.add_argument(\n        \"--skip-upload\",\n        action=\"store_true\",\n        help=\"Only save locally, don't upload to HuggingFace\",\n    )\n    parser.add_argument(\n        \"--clean\",\n        action=\"store_true\",\n        help=\"Remove local files after upload\",\n    )\n    parser.add_argument(\n        \"--clean-remote\",\n        action=\"store_true\",\n        help=\"Delete old mflux-format files from remote repo before uploading\",\n    )\n    parser.add_argument(\n        \"--dry-run\",\n        action=\"store_true\",\n        help=\"Print actions without executing\",\n    )\n\n    args = parser.parse_args()\n\n    # Determine which variants to process\n    variants: list[int | None] = []\n    if not args.skip_base:\n        variants.append(None)  # Base model (no quantization)\n    if not args.skip_4bit:\n        variants.append(4)\n    if not args.skip_8bit:\n        variants.append(8)\n\n    if not variants:\n        print(\"Error: All 
variants skipped. Nothing to do.\")\n        return 1\n\n    # Create output directory\n    args.output_dir.mkdir(parents=True, exist_ok=True)\n\n    print(f\"Model: {args.model}\")\n    print(f\"Output directory: {args.output_dir}\")\n    print(\n        f\"Variants to process: {['base' if v is None else f'{v}-bit' for v in variants]}\"\n    )\n    print(f\"Upload to HuggingFace: {not args.skip_upload}\")\n    print(f\"Clean after upload: {args.clean}\")\n    if args.dry_run:\n        print(\"\\n*** DRY RUN MODE - No actual changes will be made ***\")\n\n    # Process each variant\n    for bits in variants:\n        local_path = get_local_path(args.output_dir, args.model, bits)\n        repo_id = get_repo_name(args.model, bits)\n\n        if not args.skip_download:\n            if bits is None:\n                # Base model: copy original HF repo structure (no mflux conversion)\n                copy_source_repo(\n                    source_repo=args.model,\n                    local_path=local_path,\n                    dry_run=args.dry_run,\n                )\n            else:\n                # Quantized model: load, quantize, and save with mflux\n                load_and_save_quantized_model(\n                    model_name=args.model,\n                    bits=bits,\n                    output_path=local_path,\n                    dry_run=args.dry_run,\n                )\n\n                # Copy metadata from source repo (LICENSE, README, etc.)\n                copy_source_metadata(\n                    source_repo=args.model,\n                    local_path=local_path,\n                    dry_run=args.dry_run,\n                )\n\n        # Upload\n        if not args.skip_upload:\n            upload_to_huggingface(\n                local_path=local_path,\n                repo_id=repo_id,\n                dry_run=args.dry_run,\n                clean_remote=args.clean_remote,\n            )\n\n            # Clean up if requested\n            if args.clean:\n                clean_local_files(local_path, dry_run=args.dry_run)\n\n    print(\"\\n\" + \"=\" * 60)\n    print(\"All done!\")\n    print(\"=\" * 60)\n\n    return 0\n\n\nif __name__ == \"__main__\":\n    sys.exit(main())\n"
  },
  {
    "path": "tmp/run_llm.py",
    "content": "#!/usr/bin/env python3\nimport argparse\nimport json\nimport sys\n\nimport requests\n\n\ndef stream_chat(host: str, query: str) -> None:\n    url = f\"http://{host}:52415/v1/chat/completions\"\n    headers = {\"Content-Type\": \"application/json\"}\n    payload = {\n        \"model\": \"mlx-community/Llama-3.2-1B-Instruct-4bit\",\n        # \"model\": \"mlx-community/Llama-3_3-Nemotron-Super-49B-v1_5-mlx-4Bit\",\n        \"stream\": True,\n        \"messages\": [{\"role\": \"user\", \"content\": query}],\n    }\n\n    try:\n        with requests.post(url, headers=headers, json=payload, stream=True) as resp:\n            resp.raise_for_status()\n            for line in resp.iter_lines(decode_unicode=True):\n                if not line:\n                    continue\n\n                # SSE lines look like: \"data: {...}\" or \"data: [DONE]\"\n                if not line.startswith(\"data:\"):\n                    continue\n\n                data = line[len(\"data:\") :].strip()\n                if data == \"[DONE]\":\n                    break\n\n                try:\n                    obj = json.loads(data)\n                except json.JSONDecodeError:\n                    continue\n\n                for choice in obj.get(\"choices\", []):\n                    delta = choice.get(\"delta\") or {}\n                    content = delta.get(\"content\")\n                    if content:\n                        print(content, end=\"\", flush=True)\n\n    except requests.RequestException as e:\n        print(f\"Request failed: {e}\", file=sys.stderr)\n        sys.exit(1)\n\n    print()\n\n\ndef main() -> None:\n    parser = argparse.ArgumentParser(\n        description=\"Stream chat completions from a local server.\"\n    )\n    parser.add_argument(\"host\", help=\"Hostname (without protocol), e.g. localhost\")\n    parser.add_argument(\n        \"-f\",\n        \"--file\",\n        help=\"Path to a text file whose contents will be used as the query\",\n    )\n    parser.add_argument(\n        \"query\",\n        nargs=\"*\",\n        help=\"Query text (if not using -f/--file). All remaining arguments are joined with spaces.\",\n    )\n\n    args = parser.parse_args()\n\n    if args.file:\n        try:\n            with open(args.file, \"r\", encoding=\"utf-8\") as f:\n                query = f.read().strip()\n        except OSError as e:\n            print(f\"Error reading file {args.file}: {e}\", file=sys.stderr)\n            sys.exit(1)\n    elif args.query:\n        query = \" \".join(args.query)\n    else:\n        parser.error(\"You must provide either a query or a file (-f/--file).\")\n\n    stream_chat(args.host, query)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "tmp/run_llm.sh",
    "content": "#!/usr/bin/env bash\nset -euo pipefail\n\nif [ $# -lt 2 ]; then\n  echo \"Usage: $0 <hostname> <query>\"\n  exit 1\nfi\n\nHOST=\"$1\"\nshift\nQUERY=\"$*\"\n\ncurl -sN -X POST \"http://$HOST:52415/v1/chat/completions\" \\\n  -H \"Content-Type: application/json\" \\\n  -d \"{\n        \\\"model\\\": \\\"mlx-community/Kimi-K2-Thinking\\\",\n        \\\"stream\\\": true,\n        \\\"messages\\\": [{ \\\"role\\\": \\\"user\\\",   \\\"content\\\": \\\"$QUERY\\\"}]\n      }\" |\n  grep --line-buffered '^data:' |\n  grep --line-buffered -v 'data: \\[DONE\\]' |\n  cut -d' ' -f2- |\n  jq -r --unbuffered '.choices[].delta.content // empty' |\n  awk '{ORS=\"\"; print; fflush()} END {print \"\\n\"}'\n"
  },
  {
    "path": "tmp/set_rdma_network_config.sh",
    "content": "#!/usr/bin/env bash\n\nset -euo pipefail\n\nPREFS=\"/Library/Preferences/SystemConfiguration/preferences.plist\"\n\n# Remove bridge0 interface\nifconfig bridge0 &>/dev/null && {\n  ifconfig bridge0 | grep -q 'member' && {\n    ifconfig bridge0 | awk '/member/ {print $2}' | xargs -n1 ifconfig bridge0 deletem 2>/dev/null || true\n  }\n  ifconfig bridge0 destroy 2>/dev/null || true\n}\n\n# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist\n/usr/libexec/PlistBuddy -c \"Delete :VirtualNetworkInterfaces:Bridge:bridge0\" \"$PREFS\" 2>/dev/null || true\n\nnetworksetup -listlocations | grep -q exo || {\n  networksetup -createlocation exo\n}\n\nnetworksetup -switchtolocation exo\nnetworksetup -listallhardwareports |\n  awk -F': ' '/Hardware Port: / {print $2}' |\n  while IFS=\":\" read -r name; do\n    case \"$name\" in\n    \"Ethernet Adapter\"*) ;;\n    \"Thunderbolt Bridge\") ;;\n    \"Thunderbolt \"*)\n      networksetup -listallnetworkservices |\n        grep -q \"EXO $name\" ||\n        networksetup -createnetworkservice \"EXO $name\" \"$name\" 2>/dev/null ||\n        continue\n      networksetup -setdhcp \"EXO $name\"\n      ;;\n    *)\n      networksetup -listallnetworkservices |\n        grep -q \"$name\" ||\n        networksetup -createnetworkservice \"$name\" \"$name\" 2>/dev/null ||\n        continue\n      ;;\n    esac\n  done\n\nnetworksetup -listnetworkservices | grep -q \"Thunderbolt Bridge\" && {\n  networksetup -setnetworkserviceenabled \"Thunderbolt Bridge\" off\n} || true\n"
  },
  {
    "path": "tmp/test_trust_remote_code_attack.sh",
    "content": "#!/usr/bin/env bash\n# Test that models added via API get trust_remote_code=false\n# Run this against a running exo instance.\n# Usage: ./test_trust_remote_code_attack.sh [host:port]\n\nset -uo pipefail\n\nHOST=\"${1:-localhost:52415}\"\nMODEL_ID=\"KevTheHermit/security-testing\"\nCUSTOM_CARDS_DIR=\"$HOME/.exo/custom_model_cards\"\nCARD_FILE=\"$CUSTOM_CARDS_DIR/KevTheHermit--security-testing.toml\"\n\necho \"=== Test: trust_remote_code attack via API ===\"\necho \"Target: $HOST\"\necho \"\"\n\n# Clean up RCE proof from previous runs\nrm -f /tmp/exo-rce-proof.txt\n\n# Step 0: Clean up any stale card from previous runs\nif [ -f \"$CARD_FILE\" ]; then\n  echo \"[0] Removing stale card from previous run ...\"\n  curl -s -X DELETE \\\n    \"http://$HOST/models/custom/$(python3 -c 'import urllib.parse; print(urllib.parse.quote(\"'\"$MODEL_ID\"'\", safe=\"\"))')\" >/dev/null\n  rm -f \"$CARD_FILE\"\n  echo \"    Done\"\n  echo \"\"\nfi\n\n# Step 1: Add the malicious model via API\necho \"[1] Adding model via POST /models/add ...\"\nADD_RESPONSE=$(curl -s -w \"\\n%{http_code}\" -X POST \"http://$HOST/models/add\" \\\n  -H \"Content-Type: application/json\" \\\n  -d \"{\\\"model_id\\\":\\\"$MODEL_ID\\\"}\")\nHTTP_CODE=$(echo \"$ADD_RESPONSE\" | tail -1)\nBODY=$(echo \"$ADD_RESPONSE\" | sed '$d')\necho \"    HTTP $HTTP_CODE\"\n\nif [ \"$HTTP_CODE\" -ge 400 ]; then\n  echo \"    Model add failed (HTTP $HTTP_CODE) — that's fine if model doesn't exist on HF.\"\n  echo \"    Response: $BODY\"\n  echo \"\"\n  echo \"RESULT: Model was rejected at add time. Attack blocked.\"\n  exit 0\nfi\n\n# Step 2: Verify the saved TOML has trust_remote_code = false\necho \"\"\necho \"[2] Checking saved model card TOML ...\"\nif [ ! -f \"$CARD_FILE\" ]; then\n  echo \"    FAIL: Card file not found at $CARD_FILE\"\n  exit 1\nfi\n\nif grep -q 'trust_remote_code = false' \"$CARD_FILE\"; then\n  echo \"    SAFE: trust_remote_code = false (fix is active)\"\nelse\n  echo \"    VULNERABLE: trust_remote_code is not false — remote code WILL be trusted\"\nfi\necho \"    Contents:\"\ncat \"$CARD_FILE\"\n\n# Step 3: Place the instance\necho \"\"\necho \"[3] Attempting POST /place_instance ...\"\nPLACE_RESPONSE=$(curl -s -w \"\\n%{http_code}\" -X POST \"http://$HOST/place_instance\" \\\n  -H \"Content-Type: application/json\" \\\n  -d \"{\\\"model_id\\\":\\\"$MODEL_ID\\\"}\")\nPLACE_CODE=$(echo \"$PLACE_RESPONSE\" | tail -1)\nPLACE_BODY=$(echo \"$PLACE_RESPONSE\" | sed '$d')\necho \"    HTTP $PLACE_CODE\"\necho \"    Response: $PLACE_BODY\"\n\n# Step 3b: Send a chat completion to actually trigger tokenizer loading\necho \"\"\necho \"[3b] Sending chat completion to trigger tokenizer load ...\"\nCHAT_RESPONSE=$(curl -s -w \"\\n%{http_code}\" --max-time 30 -X POST \"http://$HOST/v1/chat/completions\" \\\n  -H \"Content-Type: application/json\" \\\n  -d \"{\\\"model\\\":\\\"$MODEL_ID\\\",\\\"messages\\\":[{\\\"role\\\":\\\"user\\\",\\\"content\\\":\\\"hello\\\"}],\\\"max_tokens\\\":1}\")\nCHAT_CODE=$(echo \"$CHAT_RESPONSE\" | tail -1)\nCHAT_BODY=$(echo \"$CHAT_RESPONSE\" | sed '$d')\necho \"    HTTP $CHAT_CODE\"\necho \"    Response: $CHAT_BODY\"\necho \"\"\necho \"[3c] Checking for RCE proof ...\"\nsleep 5\nif [ -f /tmp/exo-rce-proof.txt ]; then\n  echo \"    VULNERABLE: Remote code executed!\"\n  echo \"    Contents:\"\n  cat /tmp/exo-rce-proof.txt\nelse\n  echo \"    SAFE: /tmp/exo-rce-proof.txt does not exist — remote code was NOT executed\"\nfi\n\n# Step 4: Clean up — delete instance and custom model\necho 
\"\"\necho \"[4] Cleaning up ...\"\n\n# Find and delete any instance for this model\nINSTANCE_ID=$(curl -s \"http://$HOST/state\" | python3 -c \"\nimport sys, json\nstate = json.load(sys.stdin)\nfor iid, wrapper in state.get('instances', {}).items():\n    for tag, inst in wrapper.items():\n        sa = inst.get('shardAssignments', {})\n        if sa.get('modelId', '') == '$MODEL_ID':\n            print(iid)\n            sys.exit(0)\n\" 2>/dev/null || true)\n\nif [ -n \"$INSTANCE_ID\" ]; then\n  echo \"    Deleting instance $INSTANCE_ID ...\"\n  curl -s -X DELETE \"http://$HOST/instance/$INSTANCE_ID\" >/dev/null\n  echo \"    Done\"\nelse\n  echo \"    No instance found to delete\"\nfi\n\necho \"    Deleting custom model card ...\"\ncurl -s -X DELETE \\\n  \"http://$HOST/models/custom/$(python3 -c 'import urllib.parse; print(urllib.parse.quote(\"'\"$MODEL_ID\"'\", safe=\"\"))')\" >/dev/null\necho \"    Done\"\n\necho \"\"\necho \"=== DONE ===\"\n"
  }
]